diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index a1390ad0be..13eb40fd60 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -163,7 +163,7 @@ class PipelineGenerator(BaseAppGenerator): datasource_type=datasource_type, datasource_info=json.dumps(datasource_info), datasource_node_id=start_node_id, - input_data=inputs, + input_data=dict(inputs), pipeline_id=pipeline.id, created_by=user.id, ) diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 5bb539b7dc..e8ba2d7aab 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -274,6 +274,8 @@ class OpsTraceManager: raise ValueError("App not found") tenant_id = app.tenant_id + if trace_config_data.tracing_config is None: + raise ValueError("Tracing config cannot be None.") decrypt_tracing_config = cls.decrypt_tracing_config( tenant_id, tracing_provider, trace_config_data.tracing_config ) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 6cf6620d8d..6c818bdc8b 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -309,11 +309,12 @@ class ProviderManager: (model for model in available_models if model.model == "gpt-4"), available_models[0] ) - default_model = TenantDefaultModel() - default_model.tenant_id = tenant_id - default_model.model_type = model_type.to_origin_model_type() - default_model.provider_name = available_model.provider.provider - default_model.model_name = available_model.model + default_model = TenantDefaultModel( + tenant_id=tenant_id, + model_type=model_type.to_origin_model_type(), + provider_name=available_model.provider.provider, + model_name=available_model.model, + ) db.session.add(default_model) db.session.commit() diff --git a/api/models/account.py b/api/models/account.py index 615883f01a..b1dafed0ed 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -11,8 +11,7 @@ from sqlalchemy import DateTime, String, func, select from sqlalchemy.orm import Mapped, Session, mapped_column from typing_extensions import deprecated -from models.base import TypeBase - +from .base import TypeBase from .engine import db from .types import LongText, StringUUID diff --git a/api/models/base.py b/api/models/base.py index 78a1fdc2b7..c8a5e20f25 100644 --- a/api/models/base.py +++ b/api/models/base.py @@ -5,8 +5,9 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_co from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 -from models.engine import metadata -from models.types import StringUUID + +from .engine import metadata +from .types import StringUUID class Base(DeclarativeBase): diff --git a/api/models/dataset.py b/api/models/dataset.py index 04f107f858..4bc802bb1c 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -22,11 +22,10 @@ from core.rag.index_processor.constant.built_in_field import BuiltInField, Metad from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_storage import storage from libs.uuid_utils import uuidv7 -from models.base import TypeBase from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule from .account import Account -from .base import Base +from .base import Base, TypeBase from .engine import db from .model import App, Tag, TagBinding, UploadFile from .types import AdjustedJSON, BinaryData, LongText, StringUUID, adjusted_json_index @@ -934,21 +933,25 @@ class AppDatasetJoin(TypeBase): return db.session.get(App, self.app_id) -class DatasetQuery(Base): +class DatasetQuery(TypeBase): __tablename__ = "dataset_queries" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="dataset_query_pkey"), sa.Index("dataset_query_dataset_id_idx", "dataset_id"), ) - id = mapped_column(StringUUID, primary_key=True, nullable=False, default=lambda: str(uuid4())) - dataset_id = mapped_column(StringUUID, nullable=False) - content = mapped_column(LongText, nullable=False) + id: Mapped[str] = mapped_column( + StringUUID, primary_key=True, nullable=False, default=lambda: str(uuid4()), init=False + ) + dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + content: Mapped[str] = mapped_column(LongText, nullable=False) source: Mapped[str] = mapped_column(String(255), nullable=False) - source_app_id = mapped_column(StringUUID, nullable=True) - created_by_role = mapped_column(String(255), nullable=False) - created_by = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp()) + source_app_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False + ) class DatasetKeywordTable(TypeBase): @@ -1047,12 +1050,12 @@ class TidbAuthBinding(Base): sa.Index("tidb_auth_bindings_created_at_idx", "created_at"), sa.Index("tidb_auth_bindings_status_idx", "status"), ) - id = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4())) - tenant_id = mapped_column(StringUUID, nullable=True) + id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4())) + tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) cluster_id: Mapped[str] = mapped_column(String(255), nullable=False) cluster_name: Mapped[str] = mapped_column(String(255), nullable=False) active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) - status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'")) + status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'")) account: Mapped[str] = mapped_column(String(255), nullable=False) password: Mapped[str] = mapped_column(String(255), nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1148,7 +1151,7 @@ class ExternalKnowledgeApis(TypeBase): return dataset_bindings -class ExternalKnowledgeBindings(Base): +class ExternalKnowledgeBindings(TypeBase): __tablename__ = "external_knowledge_bindings" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"), @@ -1158,16 +1161,18 @@ class ExternalKnowledgeBindings(Base): sa.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"), ) - id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4())) - tenant_id = mapped_column(StringUUID, nullable=False) - external_knowledge_api_id = mapped_column(StringUUID, nullable=False) - dataset_id = mapped_column(StringUUID, nullable=False) - external_knowledge_id = mapped_column(String(512), nullable=False) - created_by = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - updated_by = mapped_column(StringUUID, nullable=True) + id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()), init=False) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + external_knowledge_api_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + external_knowledge_id: Mapped[str] = mapped_column(String(512), nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None, init=False) updated_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False ) @@ -1245,49 +1250,61 @@ class DatasetMetadataBinding(Base): created_by = mapped_column(StringUUID, nullable=False) -class PipelineBuiltInTemplate(Base): # type: ignore[name-defined] +class PipelineBuiltInTemplate(TypeBase): __tablename__ = "pipeline_built_in_templates" __table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),) - id = mapped_column(StringUUID, default=lambda: str(uuidv7())) - name = mapped_column(sa.String(255), nullable=False) - description = mapped_column(LongText, nullable=False) - chunk_structure = mapped_column(sa.String(255), nullable=False) - icon = mapped_column(sa.JSON, nullable=False) - yaml_content = mapped_column(LongText, nullable=False) - copyright = mapped_column(sa.String(255), nullable=False) - privacy_policy = mapped_column(sa.String(255), nullable=False) - position = mapped_column(sa.Integer, nullable=False) - install_count = mapped_column(sa.Integer, nullable=False, default=0) - language = mapped_column(sa.String(255), nullable=False) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column( - sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False) + name: Mapped[str] = mapped_column(sa.String(255), nullable=False) + description: Mapped[str] = mapped_column(LongText, nullable=False) + chunk_structure: Mapped[str] = mapped_column(sa.String(255), nullable=False) + icon: Mapped[dict] = mapped_column(sa.JSON, nullable=False) + yaml_content: Mapped[str] = mapped_column(LongText, nullable=False) + copyright: Mapped[str] = mapped_column(sa.String(255), nullable=False) + privacy_policy: Mapped[str] = mapped_column(sa.String(255), nullable=False) + position: Mapped[int] = mapped_column(sa.Integer, nullable=False) + install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False) + language: Mapped[str] = mapped_column(sa.String(255), nullable=False) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + updated_at: Mapped[datetime] = mapped_column( + sa.DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, ) -class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] +class PipelineCustomizedTemplate(TypeBase): __tablename__ = "pipeline_customized_templates" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="pipeline_customized_template_pkey"), sa.Index("pipeline_customized_template_tenant_idx", "tenant_id"), ) - id = mapped_column(StringUUID, default=lambda: str(uuidv7())) - tenant_id = mapped_column(StringUUID, nullable=False) - name = mapped_column(sa.String(255), nullable=False) - description = mapped_column(LongText, nullable=False) - chunk_structure = mapped_column(sa.String(255), nullable=False) - icon = mapped_column(sa.JSON, nullable=False) - position = mapped_column(sa.Integer, nullable=False) - yaml_content = mapped_column(LongText, nullable=False) - install_count = mapped_column(sa.Integer, nullable=False, default=0) - language = mapped_column(sa.String(255), nullable=False) - created_by = mapped_column(StringUUID, nullable=False) - updated_by = mapped_column(StringUUID, nullable=True) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column( - sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + name: Mapped[str] = mapped_column(sa.String(255), nullable=False) + description: Mapped[str] = mapped_column(LongText, nullable=False) + chunk_structure: Mapped[str] = mapped_column(sa.String(255), nullable=False) + icon: Mapped[dict] = mapped_column(sa.JSON, nullable=False) + position: Mapped[int] = mapped_column(sa.Integer, nullable=False) + yaml_content: Mapped[str] = mapped_column(LongText, nullable=False) + install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False) + language: Mapped[str] = mapped_column(sa.String(255), nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None, init=False) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + updated_at: Mapped[datetime] = mapped_column( + sa.DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, ) @property @@ -1320,34 +1337,42 @@ class Pipeline(Base): # type: ignore[name-defined] return session.query(Dataset).where(Dataset.pipeline_id == self.id).first() -class DocumentPipelineExecutionLog(Base): +class DocumentPipelineExecutionLog(TypeBase): __tablename__ = "document_pipeline_execution_logs" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="document_pipeline_execution_log_pkey"), sa.Index("document_pipeline_execution_logs_document_id_idx", "document_id"), ) - id = mapped_column(StringUUID, default=lambda: str(uuidv7())) - pipeline_id = mapped_column(StringUUID, nullable=False) - document_id = mapped_column(StringUUID, nullable=False) - datasource_type = mapped_column(sa.String(255), nullable=False) - datasource_info = mapped_column(LongText, nullable=False) - datasource_node_id = mapped_column(sa.String(255), nullable=False) - input_data = mapped_column(sa.JSON, nullable=False) - created_by = mapped_column(StringUUID, nullable=True) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False) + pipeline_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + document_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + datasource_type: Mapped[str] = mapped_column(sa.String(255), nullable=False) + datasource_info: Mapped[str] = mapped_column(LongText, nullable=False) + datasource_node_id: Mapped[str] = mapped_column(sa.String(255), nullable=False) + input_data: Mapped[dict] = mapped_column(sa.JSON, nullable=False) + created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) -class PipelineRecommendedPlugin(Base): +class PipelineRecommendedPlugin(TypeBase): __tablename__ = "pipeline_recommended_plugins" __table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),) - id = mapped_column(StringUUID, default=lambda: str(uuidv7())) - plugin_id = mapped_column(LongText, nullable=False) - provider_name = mapped_column(LongText, nullable=False) - position = mapped_column(sa.Integer, nullable=False, default=0) - active = mapped_column(sa.Boolean, nullable=False, default=True) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column( - sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False) + plugin_id: Mapped[str] = mapped_column(LongText, nullable=False) + provider_name: Mapped[str] = mapped_column(LongText, nullable=False) + position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) + active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + updated_at: Mapped[datetime] = mapped_column( + sa.DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, ) diff --git a/api/models/model.py b/api/models/model.py index fb287bcea5..b0bf46e7d7 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -31,7 +31,7 @@ from .provider_ids import GenericProviderID from .types import LongText, StringUUID if TYPE_CHECKING: - from models.workflow import Workflow + from .workflow import Workflow class DifySetup(TypeBase): @@ -1747,36 +1747,40 @@ class UploadFile(Base): self.source_url = source_url -class ApiRequest(Base): +class ApiRequest(TypeBase): __tablename__ = "api_requests" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="api_request_pkey"), sa.Index("api_request_token_idx", "tenant_id", "api_token_id"), ) - id = mapped_column(StringUUID, default=lambda: str(uuid4())) - tenant_id = mapped_column(StringUUID, nullable=False) - api_token_id = mapped_column(StringUUID, nullable=False) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + api_token_id: Mapped[str] = mapped_column(StringUUID, nullable=False) path: Mapped[str] = mapped_column(String(255), nullable=False) - request = mapped_column(LongText, nullable=True) - response = mapped_column(LongText, nullable=True) + request: Mapped[str | None] = mapped_column(LongText, nullable=True) + response: Mapped[str | None] = mapped_column(LongText, nullable=True) ip: Mapped[str] = mapped_column(String(255), nullable=False) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) -class MessageChain(Base): +class MessageChain(TypeBase): __tablename__ = "message_chains" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="message_chain_pkey"), sa.Index("message_chain_message_id_idx", "message_id"), ) - id = mapped_column(StringUUID, default=lambda: str(uuid4())) - message_id = mapped_column(StringUUID, nullable=False) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) + message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) type: Mapped[str] = mapped_column(String(255), nullable=False) - input = mapped_column(LongText, nullable=True) - output = mapped_column(LongText, nullable=True) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp()) + input: Mapped[str | None] = mapped_column(LongText, nullable=True) + output: Mapped[str | None] = mapped_column(LongText, nullable=True) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False + ) class MessageAgentThought(Base): @@ -1956,22 +1960,28 @@ class TagBinding(TypeBase): ) -class TraceAppConfig(Base): +class TraceAppConfig(TypeBase): __tablename__ = "trace_app_config" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"), sa.Index("trace_app_config_app_id_idx", "app_id"), ) - id = mapped_column(StringUUID, default=lambda: str(uuid4())) - app_id = mapped_column(StringUUID, nullable=False) - tracing_provider = mapped_column(String(255), nullable=True) - tracing_config = mapped_column(sa.JSON, nullable=True) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column( - sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + tracing_provider: Mapped[str | None] = mapped_column(String(255), nullable=True) + tracing_config: Mapped[dict | None] = mapped_column(sa.JSON, nullable=True) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) - is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) + updated_at: Mapped[datetime] = mapped_column( + sa.DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, + ) + is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), default=True) @property def tracing_config_dict(self) -> dict[str, Any]: diff --git a/api/models/provider.py b/api/models/provider.py index 5f54676389..a840a483ab 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -166,21 +166,23 @@ class ProviderModel(TypeBase): return credential.encrypted_config if credential else None -class TenantDefaultModel(Base): +class TenantDefaultModel(TypeBase): __tablename__ = "tenant_default_models" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"), sa.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), ) - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider_name: Mapped[str] = mapped_column(String(255), nullable=False) model_name: Mapped[str] = mapped_column(String(255), nullable=False) model_type: Mapped[str] = mapped_column(String(40), nullable=False) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) updated_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False ) diff --git a/api/models/source.py b/api/models/source.py index ed5d30b48a..f093048c00 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -6,8 +6,7 @@ import sqlalchemy as sa from sqlalchemy import DateTime, String, func from sqlalchemy.orm import Mapped, mapped_column -from models.base import TypeBase - +from .base import TypeBase from .types import AdjustedJSON, LongText, StringUUID, adjusted_json_index diff --git a/api/models/task.py b/api/models/task.py index 1e00e46643..539945b251 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -6,8 +6,8 @@ from sqlalchemy import DateTime, String from sqlalchemy.orm import Mapped, mapped_column from libs.datetime_utils import naive_utc_now -from models.base import TypeBase +from .base import TypeBase from .types import BinaryData, LongText diff --git a/api/models/tools.py b/api/models/tools.py index a4aeda93e5..0a79f95a70 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -12,8 +12,8 @@ from sqlalchemy.orm import Mapped, mapped_column from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration -from models.base import TypeBase +from .base import TypeBase from .engine import db from .model import Account, App, Tenant from .types import LongText, StringUUID diff --git a/api/models/trigger.py b/api/models/trigger.py index 92384b0d02..753fdb227b 100644 --- a/api/models/trigger.py +++ b/api/models/trigger.py @@ -237,7 +237,7 @@ class WorkflowTriggerLog(Base): @property def created_by_end_user(self): - from models.model import EndUser + from .model import EndUser created_by_role = CreatorUserRole(self.created_by_role) return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None diff --git a/api/models/web.py b/api/models/web.py index 6e1a90af87..4f0bf7c7da 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -5,8 +5,7 @@ import sqlalchemy as sa from sqlalchemy import DateTime, String, func from sqlalchemy.orm import Mapped, mapped_column -from models.base import TypeBase - +from .base import TypeBase from .engine import db from .model import Message from .types import StringUUID diff --git a/api/models/workflow.py b/api/models/workflow.py index d8d3e1e540..3ebc36bee3 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -86,7 +86,7 @@ class WorkflowType(StrEnum): :param app_mode: app mode :return: workflow type """ - from models.model import AppMode + from .model import AppMode app_mode = app_mode if isinstance(app_mode, AppMode) else AppMode.value_of(app_mode) return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT @@ -413,7 +413,7 @@ class Workflow(Base): For accurate checking, use a direct query with tenant_id, app_id, and version. """ - from models.tools import WorkflowToolProvider + from .tools import WorkflowToolProvider stmt = select( exists().where( @@ -634,7 +634,7 @@ class WorkflowRun(Base): @property def created_by_end_user(self): - from models.model import EndUser + from .model import EndUser created_by_role = CreatorUserRole(self.created_by_role) return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None @@ -653,7 +653,7 @@ class WorkflowRun(Base): @property def message(self): - from models.model import Message + from .model import Message return ( db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first() @@ -874,7 +874,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo @property def created_by_end_user(self): - from models.model import EndUser + from .model import EndUser created_by_role = CreatorUserRole(self.created_by_role) # TODO(-LAN-): Avoid using db.session.get() here. @@ -1130,7 +1130,7 @@ class WorkflowAppLog(TypeBase): @property def created_by_end_user(self): - from models.model import EndUser + from .model import EndUser created_by_role = CreatorUserRole(self.created_by_role) return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 037ef469d2..abfb4baeec 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -254,6 +254,8 @@ class DatasetService: external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id) if not external_knowledge_api: raise ValueError("External API template not found.") + if external_knowledge_id is None: + raise ValueError("external_knowledge_id is required") external_knowledge_binding = ExternalKnowledgeBindings( tenant_id=tenant_id, dataset_id=dataset.id, diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 0c2a80d067..27936f6278 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -257,12 +257,16 @@ class ExternalDatasetService: db.session.add(dataset) db.session.flush() + if args.get("external_knowledge_id") is None: + raise ValueError("external_knowledge_id is required") + if args.get("external_knowledge_api_id") is None: + raise ValueError("external_knowledge_api_id is required") external_knowledge_binding = ExternalKnowledgeBindings( tenant_id=tenant_id, dataset_id=dataset.id, - external_knowledge_api_id=args.get("external_knowledge_api_id"), - external_knowledge_id=args.get("external_knowledge_id"), + external_knowledge_api_id=args.get("external_knowledge_api_id") or "", + external_knowledge_id=args.get("external_knowledge_id") or "", created_by=user_id, ) db.session.add(external_knowledge_binding) diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 337181728c..cdbd2355ca 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -82,7 +82,12 @@ class HitTestingService: logger.debug("Hit testing retrieve in %s seconds", end - start) dataset_query = DatasetQuery( - dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id + dataset_id=dataset.id, + content=query, + source="hit_testing", + source_app_id=None, + created_by_role="account", + created_by=account.id, ) db.session.add(dataset_query) @@ -118,7 +123,12 @@ class HitTestingService: logger.debug("External knowledge hit testing retrieve in %s seconds", end - start) dataset_query = DatasetQuery( - dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id + dataset_id=dataset.id, + content=query, + source="hit_testing", + source_app_id=None, + created_by_role="account", + created_by=account.id, ) db.session.add(dataset_query) diff --git a/api/services/ops_service.py b/api/services/ops_service.py index e490b7ed3c..a2c8e9118e 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -29,6 +29,8 @@ class OpsService: if not app: return None tenant_id = app.tenant_id + if trace_config_data.tracing_config is None: + raise ValueError("Tracing config cannot be None.") decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config( tenant_id, tracing_provider, trace_config_data.tracing_config ) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index fed7a25e21..097d16e2a7 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -1119,13 +1119,19 @@ class RagPipelineService: with Session(db.engine) as session: rag_pipeline_dsl_service = RagPipelineDslService(session) dsl = rag_pipeline_dsl_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True) - + if args.get("icon_info") is None: + args["icon_info"] = {} + if args.get("description") is None: + raise ValueError("Description is required") + if args.get("name") is None: + raise ValueError("Name is required") pipeline_customized_template = PipelineCustomizedTemplate( - name=args.get("name"), - description=args.get("description"), - icon=args.get("icon_info"), + name=args.get("name") or "", + description=args.get("description") or "", + icon=args.get("icon_info") or {}, tenant_id=pipeline.tenant_id, yaml_content=dsl, + install_count=0, position=max_position + 1 if max_position else 1, chunk_structure=dataset.chunk_structure, language="en-US", diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index d79ab71668..22025dd44a 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -322,9 +322,9 @@ class RagPipelineTransformService: datasource_info=data_source_info, input_data={}, created_by=document.created_by, - created_at=document.created_at, datasource_node_id=file_node_id, ) + document_pipeline_execution_log.created_at = document.created_at db.session.add(document) db.session.add(document_pipeline_execution_log) elif document.data_source_type == "notion_import": @@ -350,9 +350,9 @@ class RagPipelineTransformService: datasource_info=data_source_info, input_data={}, created_by=document.created_by, - created_at=document.created_at, datasource_node_id=notion_node_id, ) + document_pipeline_execution_log.created_at = document.created_at db.session.add(document) db.session.add(document_pipeline_execution_log) elif document.data_source_type == "website_crawl": @@ -379,8 +379,8 @@ class RagPipelineTransformService: datasource_info=data_source_info, input_data={}, created_by=document.created_by, - created_at=document.created_at, datasource_node_id=datasource_node_id, ) + document_pipeline_execution_log.created_at = document.created_at db.session.add(document) db.session.add(document_pipeline_execution_log)