mirror of
https://github.com/langgenius/dify.git
synced 2026-05-12 07:37:09 +08:00
Merge branch 'main' into jzh
This commit is contained in:
commit
8c6dda125f
@ -1,4 +1,5 @@
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
@ -23,9 +24,9 @@ class ConversationRenamePayload(BaseModel):
|
||||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUIDStrOrEmpty
|
||||
first_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
conversation_id: UUIDStrOrEmpty = Field(description="Conversation UUID")
|
||||
first_id: UUIDStrOrEmpty | None = Field(default=None, description="First message ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
@ -69,11 +70,35 @@ class WorkflowUpdatePayload(BaseModel):
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
# --- Dataset schemas ---
|
||||
|
||||
|
||||
DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
|
||||
|
||||
|
||||
class ChildChunkCreatePayload(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class ChildChunkUpdatePayload(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class DocumentBatchDownloadZipPayload(BaseModel):
|
||||
"""Request payload for bulk downloading documents as a zip archive."""
|
||||
|
||||
document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
|
||||
|
||||
|
||||
class MetadataUpdatePayload(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
# --- Audio schemas ---
|
||||
|
||||
|
||||
class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = None
|
||||
voice: str | None = None
|
||||
text: str | None = None
|
||||
streaming: bool | None = None
|
||||
message_id: str | None = Field(default=None, description="Message ID")
|
||||
voice: str | None = Field(default=None, description="Voice to use for TTS")
|
||||
text: str | None = Field(default=None, description="Text to convert to audio")
|
||||
streaming: bool | None = Field(default=None, description="Enable streaming response")
|
||||
|
||||
@ -4,7 +4,6 @@ from argparse import ArgumentTypeError
|
||||
from collections.abc import Sequence
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Literal, cast
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request, send_file
|
||||
@ -16,6 +15,7 @@ from sqlalchemy import asc, desc, func, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload
|
||||
from controllers.common.schema import get_or_create_model, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from core.errors.error import (
|
||||
@ -71,9 +71,6 @@ from ..wraps import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# NOTE: Keep constants near the top of the module for discoverability.
|
||||
DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
|
||||
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
dataset_model = get_or_create_model("Dataset", dataset_fields)
|
||||
@ -110,12 +107,6 @@ class GenerateSummaryPayload(BaseModel):
|
||||
document_list: list[str]
|
||||
|
||||
|
||||
class DocumentBatchDownloadZipPayload(BaseModel):
|
||||
"""Request payload for bulk downloading documents as a zip archive."""
|
||||
|
||||
document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
|
||||
|
||||
|
||||
class DocumentDatasetListParam(BaseModel):
|
||||
page: int = Field(1, title="Page", description="Page number.")
|
||||
limit: int = Field(20, title="Limit", description="Page size.")
|
||||
|
||||
@ -10,6 +10,7 @@ from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
@ -82,14 +83,6 @@ class BatchImportPayload(BaseModel):
|
||||
upload_file_id: str
|
||||
|
||||
|
||||
class ChildChunkCreatePayload(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class ChildChunkUpdatePayload(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class ChildChunkBatchUpdatePayload(BaseModel):
|
||||
chunks: list[ChildChunkUpdateArgs]
|
||||
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MetadataUpdatePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
@ -18,11 +18,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
)
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
class MetadataUpdatePayload(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail
|
||||
)
|
||||
|
||||
@ -3,10 +3,10 @@ import logging
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import TextToAudioPayload
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
@ -86,13 +86,6 @@ class AudioApi(Resource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = Field(default=None, description="Message ID")
|
||||
voice: str | None = Field(default=None, description="Voice to use for TTS")
|
||||
text: str | None = Field(default=None, description="Text to convert to audio")
|
||||
streaming: bool | None = Field(default=None, description="Enable streaming response")
|
||||
|
||||
|
||||
register_schema_model(service_api_ns, TextToAudioPayload)
|
||||
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ from sqlalchemy import desc, func, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload
|
||||
from controllers.common.errors import (
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
@ -100,15 +101,6 @@ class DocumentListQuery(BaseModel):
|
||||
status: str | None = Field(default=None, description="Document status filter")
|
||||
|
||||
|
||||
DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
|
||||
|
||||
|
||||
class DocumentBatchDownloadZipPayload(BaseModel):
|
||||
"""Request payload for bulk downloading uploaded documents as a ZIP archive."""
|
||||
|
||||
document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
|
||||
|
||||
|
||||
register_enum_models(service_api_ns, RetrievalMethod)
|
||||
|
||||
register_schema_models(
|
||||
|
||||
@ -2,9 +2,9 @@ from typing import Literal
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MetadataUpdatePayload
|
||||
from controllers.common.schema import register_schema_model, register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
|
||||
@ -18,11 +18,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
)
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
class MetadataUpdatePayload(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
register_schema_model(service_api_ns, MetadataUpdatePayload)
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
|
||||
@ -8,6 +8,7 @@ from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
@ -69,20 +70,12 @@ class SegmentUpdatePayload(BaseModel):
|
||||
segment: SegmentUpdateArgs
|
||||
|
||||
|
||||
class ChildChunkCreatePayload(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class ChildChunkListQuery(BaseModel):
|
||||
limit: int = Field(default=20, ge=1)
|
||||
keyword: str | None = None
|
||||
page: int = Field(default=1, ge=1)
|
||||
|
||||
|
||||
class ChildChunkUpdatePayload(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
SegmentCreatePayload,
|
||||
|
||||
@ -92,7 +92,7 @@ class HumanInputFormApi(Resource):
|
||||
_FORM_ACCESS_RATE_LIMITER.increment_rate_limit(ip_address)
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
# TODO(QuantumGhost): forbid submision for form tokens
|
||||
# TODO(QuantumGhost): forbid submission for form tokens
|
||||
# that are only for console.
|
||||
form = service.get_form_by_token(form_token)
|
||||
|
||||
|
||||
@ -3,10 +3,10 @@ from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MessageFeedbackPayload
|
||||
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import (
|
||||
@ -25,7 +25,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
@ -41,19 +40,6 @@ from services.message_service import MessageService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: str = Field(description="Conversation UUID")
|
||||
first_id: str | None = Field(default=None, description="First message ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
|
||||
|
||||
@field_validator("conversation_id", "first_id")
|
||||
@classmethod
|
||||
def validate_uuid(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class MessageMoreLikeThisQuery(BaseModel):
|
||||
response_mode: Literal["blocking", "streaming"] = Field(
|
||||
description="Response mode",
|
||||
|
||||
@ -4,7 +4,12 @@ from core.rag.entities.event import DatasourceCompletedEvent, DatasourceErrorEve
|
||||
from core.rag.entities.index_entities import EconomySetting, EmbeddingSetting, IndexMethod
|
||||
from core.rag.entities.metadata_entities import Condition, MetadataFilteringCondition, SupportedComparisonOperator
|
||||
from core.rag.entities.processing_entities import ParentMode, PreProcessingRule, Rule, Segmentation
|
||||
from core.rag.entities.retrieval_settings import KeywordSetting, VectorSetting, WeightedScoreConfig
|
||||
from core.rag.entities.retrieval_settings import (
|
||||
KeywordSetting,
|
||||
RerankingModelConfig,
|
||||
VectorSetting,
|
||||
WeightedScoreConfig,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Condition",
|
||||
@ -19,6 +24,7 @@ __all__ = [
|
||||
"MetadataFilteringCondition",
|
||||
"ParentMode",
|
||||
"PreProcessingRule",
|
||||
"RerankingModelConfig",
|
||||
"RetrievalSourceMetadata",
|
||||
"Rule",
|
||||
"Segmentation",
|
||||
|
||||
@ -1,4 +1,27 @@
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class RerankingModelConfig(BaseModel):
|
||||
"""
|
||||
Canonical reranking model configuration.
|
||||
|
||||
Accepts both naming conventions:
|
||||
- reranking_provider_name / reranking_model_name (services layer)
|
||||
- provider / model (workflow layer via validation_alias)
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
reranking_provider_name: str = Field(validation_alias="provider")
|
||||
reranking_model_name: str = Field(validation_alias="model")
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
return self.reranking_provider_name
|
||||
|
||||
@property
|
||||
def model(self) -> str:
|
||||
return self.reranking_model_name
|
||||
|
||||
|
||||
class VectorSetting(BaseModel):
|
||||
|
||||
@ -4,21 +4,12 @@ from graphon.entities.base_node_data import BaseNodeData
|
||||
from graphon.enums import NodeType
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.entities.retrieval_settings import WeightedScoreConfig
|
||||
from core.rag.entities import RerankingModelConfig, WeightedScoreConfig
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
|
||||
|
||||
|
||||
class RerankingModelConfig(BaseModel):
|
||||
"""
|
||||
Reranking Model Config.
|
||||
"""
|
||||
|
||||
reranking_provider_name: str
|
||||
reranking_model_name: str
|
||||
|
||||
|
||||
class RetrievalSetting(BaseModel):
|
||||
"""
|
||||
Retrieval Setting.
|
||||
|
||||
@ -5,20 +5,11 @@ from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
from graphon.nodes.llm.entities import ModelConfig, VisionConfig
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.entities import Condition, MetadataFilteringCondition, WeightedScoreConfig
|
||||
from core.rag.entities import Condition, MetadataFilteringCondition, RerankingModelConfig, WeightedScoreConfig
|
||||
|
||||
__all__ = ["Condition"]
|
||||
|
||||
|
||||
class RerankingModelConfig(BaseModel):
|
||||
"""
|
||||
Reranking Model Config.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
model: str
|
||||
|
||||
|
||||
class MultipleRetrievalConfig(BaseModel):
|
||||
"""
|
||||
Multiple Retrieval Config.
|
||||
|
||||
@ -61,7 +61,7 @@ from factories import variable_factory
|
||||
from libs import helper
|
||||
|
||||
from .account import Account
|
||||
from .base import Base, DefaultFieldsMixin, TypeBase
|
||||
from .base import Base, DefaultFieldsDCMixin, TypeBase
|
||||
from .engine import db
|
||||
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
@ -742,8 +742,8 @@ class WorkflowRun(Base):
|
||||
exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
|
||||
|
||||
pause: Mapped[Optional["WorkflowPause"]] = orm.relationship(
|
||||
"WorkflowPause",
|
||||
primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)",
|
||||
lambda: WorkflowPause,
|
||||
primaryjoin=lambda: WorkflowRun.id == orm.foreign(WorkflowPause.workflow_run_id),
|
||||
uselist=False,
|
||||
# require explicit preloading.
|
||||
lazy="raise",
|
||||
@ -1196,6 +1196,18 @@ class WorkflowAppLogCreatedFrom(StrEnum):
|
||||
raise ValueError(f"invalid workflow app log created from value {value}")
|
||||
|
||||
|
||||
class WorkflowAppLogDict(TypedDict):
|
||||
id: str
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
workflow_id: str
|
||||
workflow_run_id: str
|
||||
created_from: WorkflowAppLogCreatedFrom
|
||||
created_by_role: CreatorUserRole
|
||||
created_by: str
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class WorkflowAppLog(TypeBase):
|
||||
"""
|
||||
Workflow App execution log, excluding workflow debugging records.
|
||||
@ -1273,8 +1285,8 @@ class WorkflowAppLog(TypeBase):
|
||||
created_by_role = CreatorUserRole(self.created_by_role)
|
||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
def to_dict(self) -> WorkflowAppLogDict:
|
||||
result: WorkflowAppLogDict = {
|
||||
"id": self.id,
|
||||
"tenant_id": self.tenant_id,
|
||||
"app_id": self.app_id,
|
||||
@ -1285,6 +1297,7 @@ class WorkflowAppLog(TypeBase):
|
||||
"created_by": self.created_by,
|
||||
"created_at": self.created_at,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
class WorkflowArchiveLog(TypeBase):
|
||||
@ -1941,7 +1954,7 @@ def is_system_variable_editable(name: str) -> bool:
|
||||
return name in _EDITABLE_SYSTEM_VARIABLE
|
||||
|
||||
|
||||
class WorkflowPause(DefaultFieldsMixin, Base):
|
||||
class WorkflowPause(DefaultFieldsDCMixin, TypeBase):
|
||||
"""
|
||||
WorkflowPause records the paused state and related metadata for a specific workflow run.
|
||||
|
||||
@ -1980,6 +1993,11 @@ class WorkflowPause(DefaultFieldsMixin, Base):
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# state_object_key stores the object key referencing the serialized runtime state
|
||||
# of the `GraphEngine`. This object captures the complete execution context of the
|
||||
# workflow at the moment it was paused, enabling accurate resumption.
|
||||
state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)
|
||||
|
||||
# `resumed_at` records the timestamp when the suspended workflow was resumed.
|
||||
# It is set to `NULL` if the workflow has not been resumed.
|
||||
#
|
||||
@ -1988,25 +2006,23 @@ class WorkflowPause(DefaultFieldsMixin, Base):
|
||||
resumed_at: Mapped[datetime | None] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=True,
|
||||
default=None,
|
||||
)
|
||||
|
||||
# state_object_key stores the object key referencing the serialized runtime state
|
||||
# of the `GraphEngine`. This object captures the complete execution context of the
|
||||
# workflow at the moment it was paused, enabling accurate resumption.
|
||||
state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)
|
||||
|
||||
# Relationship to WorkflowRun
|
||||
# Relationship to WorkflowRun (uses lambda to resolve across Base/TypeBase registries)
|
||||
workflow_run: Mapped["WorkflowRun"] = orm.relationship(
|
||||
lambda: WorkflowRun,
|
||||
foreign_keys=[workflow_run_id],
|
||||
# require explicit preloading.
|
||||
lazy="raise",
|
||||
uselist=False,
|
||||
primaryjoin="WorkflowPause.workflow_run_id == WorkflowRun.id",
|
||||
primaryjoin=lambda: WorkflowPause.workflow_run_id == WorkflowRun.id,
|
||||
back_populates="pause",
|
||||
init=False,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowPauseReason(DefaultFieldsMixin, Base):
|
||||
class WorkflowPauseReason(DefaultFieldsDCMixin, TypeBase):
|
||||
__tablename__ = "workflow_pause_reasons"
|
||||
|
||||
# `pause_id` represents the identifier of the pause,
|
||||
@ -2049,16 +2065,20 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base):
|
||||
lazy="raise",
|
||||
uselist=False,
|
||||
primaryjoin="WorkflowPauseReason.pause_id == WorkflowPause.id",
|
||||
init=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason":
|
||||
def from_entity(cls, *, pause_id: str, pause_reason: PauseReason) -> "WorkflowPauseReason":
|
||||
if isinstance(pause_reason, HumanInputRequired):
|
||||
return cls(
|
||||
type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id
|
||||
pause_id=pause_id,
|
||||
type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
|
||||
form_id=pause_reason.form_id,
|
||||
node_id=pause_reason.node_id,
|
||||
)
|
||||
elif isinstance(pause_reason, SchedulingPause):
|
||||
return cls(type_=PauseReasonType.SCHEDULED_PAUSE, message=pause_reason.message, node_id="")
|
||||
return cls(pause_id=pause_id, type_=PauseReasonType.SCHEDULED_PAUSE, message=pause_reason.message)
|
||||
else:
|
||||
raise AssertionError(f"Unknown pause reason type: {pause_reason}")
|
||||
|
||||
|
||||
@ -41,7 +41,6 @@ from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import convert_datetime_to_date
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.time_parser import get_time_threshold
|
||||
from libs.uuid_utils import uuidv7
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.human_input import HumanInputForm
|
||||
from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
|
||||
@ -744,12 +743,11 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
# Upload the state file
|
||||
|
||||
# Create the pause record
|
||||
pause_model = WorkflowPause()
|
||||
pause_model.id = str(uuidv7())
|
||||
pause_model.workflow_id = workflow_run.workflow_id
|
||||
pause_model.workflow_run_id = workflow_run.id
|
||||
pause_model.state_object_key = state_obj_key
|
||||
pause_model.created_at = naive_utc_now()
|
||||
pause_model = WorkflowPause(
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_object_key=state_obj_key,
|
||||
)
|
||||
pause_reason_models = []
|
||||
for reason in pause_reasons:
|
||||
if isinstance(reason, HumanInputRequired):
|
||||
|
||||
@ -7,6 +7,11 @@ from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
|
||||
|
||||
class RerankingModel(BaseModel):
|
||||
reranking_provider_name: str | None = None
|
||||
reranking_model_name: str | None = None
|
||||
|
||||
|
||||
class NotionIcon(BaseModel):
|
||||
type: str
|
||||
url: str | None = None
|
||||
@ -53,11 +58,6 @@ class ProcessRule(BaseModel):
|
||||
rules: Rule | None = None
|
||||
|
||||
|
||||
class RerankingModel(BaseModel):
|
||||
reranking_provider_name: str | None = None
|
||||
reranking_model_name: str | None = None
|
||||
|
||||
|
||||
class WeightVectorSetting(BaseModel):
|
||||
vector_weight: float
|
||||
embedding_provider_name: str
|
||||
|
||||
@ -6,6 +6,24 @@ from core.rag.entities import KeywordSetting, VectorSetting
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
|
||||
|
||||
class RerankingModelConfig(BaseModel):
|
||||
"""
|
||||
Reranking Model Config.
|
||||
"""
|
||||
|
||||
reranking_provider_name: str | None = ""
|
||||
reranking_model_name: str | None = ""
|
||||
|
||||
|
||||
class WeightedScoreConfig(BaseModel):
|
||||
"""
|
||||
Weighted score Config.
|
||||
"""
|
||||
|
||||
vector_setting: VectorSetting | None
|
||||
keyword_setting: KeywordSetting | None
|
||||
|
||||
|
||||
class IconInfo(BaseModel):
|
||||
icon: str
|
||||
icon_background: str | None = None
|
||||
@ -28,24 +46,6 @@ class RagPipelineDatasetCreateEntity(BaseModel):
|
||||
yaml_content: str | None = None
|
||||
|
||||
|
||||
class RerankingModelConfig(BaseModel):
|
||||
"""
|
||||
Reranking Model Config.
|
||||
"""
|
||||
|
||||
reranking_provider_name: str | None = ""
|
||||
reranking_model_name: str | None = ""
|
||||
|
||||
|
||||
class WeightedScoreConfig(BaseModel):
|
||||
"""
|
||||
Weighted score Config.
|
||||
"""
|
||||
|
||||
vector_setting: VectorSetting | None
|
||||
keyword_setting: KeywordSetting | None
|
||||
|
||||
|
||||
class RetrievalSetting(BaseModel):
|
||||
"""
|
||||
Retrieval Setting.
|
||||
|
||||
@ -17,6 +17,7 @@ from core.helper import encrypter
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.entities import AuthActionType, AuthResult
|
||||
from core.mcp.error import MCPAuthError, MCPError
|
||||
from core.mcp.types import Tool as MCPTool
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||
@ -496,7 +497,13 @@ class MCPToolManageService:
|
||||
) as mcp_client:
|
||||
return mcp_client.list_tools()
|
||||
|
||||
def execute_auth_actions(self, auth_result: Any) -> dict[str, str]:
|
||||
_ACTION_TO_OAUTH: dict[AuthActionType, OAuthDataType] = {
|
||||
AuthActionType.SAVE_CLIENT_INFO: OAuthDataType.CLIENT_INFO,
|
||||
AuthActionType.SAVE_TOKENS: OAuthDataType.TOKENS,
|
||||
AuthActionType.SAVE_CODE_VERIFIER: OAuthDataType.CODE_VERIFIER,
|
||||
}
|
||||
|
||||
def execute_auth_actions(self, auth_result: AuthResult) -> dict[str, str]:
|
||||
"""
|
||||
Execute the actions returned by the auth function.
|
||||
|
||||
@ -508,19 +515,13 @@ class MCPToolManageService:
|
||||
Returns:
|
||||
The response from the auth result
|
||||
"""
|
||||
from core.mcp.entities import AuthAction, AuthActionType
|
||||
|
||||
action: AuthAction
|
||||
for action in auth_result.actions:
|
||||
if action.provider_id is None or action.tenant_id is None:
|
||||
continue
|
||||
|
||||
if action.action_type == AuthActionType.SAVE_CLIENT_INFO:
|
||||
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CLIENT_INFO)
|
||||
elif action.action_type == AuthActionType.SAVE_TOKENS:
|
||||
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.TOKENS)
|
||||
elif action.action_type == AuthActionType.SAVE_CODE_VERIFIER:
|
||||
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CODE_VERIFIER)
|
||||
oauth_type = self._ACTION_TO_OAUTH.get(action.action_type)
|
||||
if oauth_type is not None:
|
||||
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, oauth_type)
|
||||
|
||||
return auth_result.response
|
||||
|
||||
|
||||
@ -220,7 +220,6 @@ class TestDeleteRunsWithRelated:
|
||||
created_by=test_scope.user_id,
|
||||
)
|
||||
pause = WorkflowPause(
|
||||
id=str(uuid4()),
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_object_key=f"workflow-state-{uuid4()}.json",
|
||||
@ -280,7 +279,6 @@ class TestCountRunsWithRelated:
|
||||
created_by=test_scope.user_id,
|
||||
)
|
||||
pause = WorkflowPause(
|
||||
id=str(uuid4()),
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_object_key=f"workflow-state-{uuid4()}.json",
|
||||
@ -544,7 +542,6 @@ class TestPrivateWorkflowPauseEntity:
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
)
|
||||
pause = WorkflowPause(
|
||||
id=str(uuid4()),
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_object_key=f"workflow-state-{uuid4()}.json",
|
||||
@ -574,7 +571,6 @@ class TestPrivateWorkflowPauseEntity:
|
||||
)
|
||||
state_key = f"workflow-state-{uuid4()}.json"
|
||||
pause = WorkflowPause(
|
||||
id=str(uuid4()),
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_object_key=state_key,
|
||||
@ -606,7 +602,6 @@ class TestPrivateWorkflowPauseEntity:
|
||||
)
|
||||
state_key = f"workflow-state-{uuid4()}.json"
|
||||
pause = WorkflowPause(
|
||||
id=str(uuid4()),
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_object_key=state_key,
|
||||
@ -672,7 +667,6 @@ class TestBuildHumanInputRequiredReason:
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
)
|
||||
pause = WorkflowPause(
|
||||
id=str(uuid4()),
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_object_key=f"workflow-state-{uuid4()}.json",
|
||||
|
||||
@ -198,7 +198,7 @@ class TestMessageListQuery:
|
||||
assert q.limit == 20
|
||||
|
||||
def test_invalid_conversation_id(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid uuid"):
|
||||
with pytest.raises(ValidationError, match="must be a valid UUID"):
|
||||
MessageListQuery(conversation_id="bad")
|
||||
|
||||
def test_limit_bounds(self) -> None:
|
||||
@ -216,7 +216,7 @@ class TestMessageListQuery:
|
||||
|
||||
def test_invalid_first_id(self) -> None:
|
||||
cid = str(uuid4())
|
||||
with pytest.raises(ValidationError, match="not a valid uuid"):
|
||||
with pytest.raises(ValidationError, match="must be a valid UUID"):
|
||||
MessageListQuery(conversation_id=cid, first_id="invalid")
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user