mirror of
https://github.com/langgenius/dify.git
synced 2026-06-14 21:01:08 +08:00
feat: evaluation (#35353)
Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: jyong <718720800@qq.com> Co-authored-by: Yansong Zhang <916125788@qq.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: hj24 <mambahj24@gmail.com> Co-authored-by: hj24 <huangjian@dify.ai> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com> Co-authored-by: CodingOnStar <hanxujiang@dify.com> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: 非法操作 <hjlarry@163.com> Co-authored-by: Ayush Baluni <73417844+aayushbaluni@users.noreply.github.com> Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com> Co-authored-by: jimcody1995 <jjimcody@gmail.com> Co-authored-by: James <63717587+jamesrayammons@users.noreply.github.com> Co-authored-by: Yunlu Wen <yunlu.wen@dify.ai> Co-authored-by: Stephen Zhou <hi@hyoban.cc> Co-authored-by: Coding On Star <447357187@qq.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: jerryzai <jerryzh8710@protonmail.com> Co-authored-by: NVIDIAN <speedy.hpc@hotmail.com> Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org> Co-authored-by: Junghwan <70629228+shaun0927@users.noreply.github.com> Co-authored-by: HeYinKazune <70251095+HeYin-OS@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: yyh <yuanyouhuilyz@gmail.com> Co-authored-by: Jingyi <jingyi.qi@dify.ai> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com> Co-authored-by: sxxtony <166789813+sxxtony@users.noreply.github.com>
This commit is contained in:
parent
90384b26b3
commit
0e320290e1
1
.github/workflows/autofix.yml
vendored
1
.github/workflows/autofix.yml
vendored
@ -120,7 +120,6 @@ jobs:
|
||||
- name: ESLint autofix
|
||||
if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
cd web
|
||||
vp exec eslint --concurrency=2 --prune-suppressions --quiet || true
|
||||
|
||||
- if: github.event_name != 'merge_group'
|
||||
|
||||
@ -1,12 +1,16 @@
|
||||
{
|
||||
// Disable the default formatter, use eslint instead
|
||||
"prettier.enable": false,
|
||||
"editor.formatOnSave": false,
|
||||
"cucumber.features": [
|
||||
"e2e/features/**/*.feature",
|
||||
],
|
||||
"cucumber.glue": [
|
||||
"e2e/features/**/*.ts",
|
||||
],
|
||||
|
||||
"tailwindCSS.experimental.configFile": "web/app/styles/globals.css",
|
||||
|
||||
// Auto fix
|
||||
"editor.codeActionsOnSave": {
|
||||
"source.fixAll.eslint": "explicit",
|
||||
"source.organizeImports": "never"
|
||||
},
|
||||
|
||||
// Silent the stylistic rules in your IDE, but still auto fix them
|
||||
@ -106,3 +106,6 @@ msg = "Use Pydantic payload/query models instead of reqparse."
|
||||
|
||||
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse.RequestParser"]
|
||||
msg = "Use Pydantic payload/query models instead of reqparse."
|
||||
|
||||
[lint.isort]
|
||||
known-first-party = ["graphon"]
|
||||
@ -2,9 +2,9 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from graphon.file import helpers as file_helpers
|
||||
from pydantic import BaseModel, ConfigDict, computed_field
|
||||
|
||||
from graphon.file import helpers as file_helpers
|
||||
from models.model import IconType
|
||||
|
||||
type JSONValue = str | int | float | bool | None | dict[str, Any] | list[Any]
|
||||
|
||||
@ -125,6 +125,9 @@ from .explore import (
|
||||
from .snippets import snippet_workflow, snippet_workflow_draft_variable
|
||||
from .socketio import workflow as socketio_workflow # pyright: ignore[reportUnusedImport]
|
||||
|
||||
# Import snippet controllers
|
||||
from .snippets import snippet_workflow, snippet_workflow_draft_variable
|
||||
|
||||
# Import tag controllers
|
||||
from .tag import tags
|
||||
|
||||
@ -215,6 +218,9 @@ __all__ = [
|
||||
"snippet_workflow_draft_variable",
|
||||
"snippets",
|
||||
"socketio_workflow",
|
||||
"snippet_workflow",
|
||||
"snippet_workflow_draft_variable",
|
||||
"snippets",
|
||||
"spec",
|
||||
"statistic",
|
||||
"tags",
|
||||
|
||||
@ -4,7 +4,6 @@ from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import exists, func, select
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
@ -40,6 +39,7 @@ from fields.conversation_fields import (
|
||||
format_files_contained,
|
||||
to_timestamp,
|
||||
)
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
|
||||
@ -5,10 +5,6 @@ from typing import Any, TypedDict
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from graphon.file import helpers as file_helpers
|
||||
from graphon.variables.segment_group import SegmentGroup
|
||||
from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from graphon.variables.types import SegmentType
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
@ -25,6 +21,10 @@ from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from graphon.file import helpers as file_helpers
|
||||
from graphon.variables.segment_group import SegmentGroup
|
||||
from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from graphon.variables.types import SegmentType
|
||||
from libs.login import current_user, login_required
|
||||
from models import App, AppMode
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
|
||||
@ -3,12 +3,11 @@ from __future__ import annotations
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Union
|
||||
from typing import TYPE_CHECKING, ParamSpec, TypeVar, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, fields, marshal
|
||||
from graphon.file import helpers as file_helpers
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -26,6 +25,7 @@ from core.evaluation.entities.evaluation_entity import EvaluationCategory, Evalu
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from fields.member_fields import simple_account_fields
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, Dataset
|
||||
@ -45,6 +45,9 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
# Valid evaluation target types
|
||||
EVALUATE_TARGET_TYPES = {"app", "snippets"}
|
||||
|
||||
@ -181,7 +184,7 @@ evaluation_default_metrics_response_model = console_ns.model(
|
||||
)
|
||||
|
||||
|
||||
def get_evaluation_target[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
def get_evaluation_target(view_func: Callable[P, R]):
|
||||
"""
|
||||
Decorator to resolve polymorphic evaluation target (app or snippet).
|
||||
|
||||
@ -190,7 +193,7 @@ def get_evaluation_target[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
"""
|
||||
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
target_type = kwargs.get("evaluate_target_type")
|
||||
target_id = kwargs.get("evaluate_target_id")
|
||||
|
||||
|
||||
@ -4,7 +4,6 @@ from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from graphon.file import helpers as file_helpers
|
||||
from pydantic import BaseModel, Field, computed_field, field_validator
|
||||
from sqlalchemy import and_, select
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
@ -15,6 +14,7 @@ from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, InstalledApp, RecommendedApp
|
||||
|
||||
@ -78,6 +78,13 @@ class SnippetDraftSyncPayload(BaseModel):
|
||||
input_fields: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class SnippetWorkflowListQuery(BaseModel):
|
||||
"""Query parameters for listing snippet published workflows."""
|
||||
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
|
||||
|
||||
class WorkflowRunQuery(BaseModel):
|
||||
"""Query parameters for workflow runs."""
|
||||
|
||||
|
||||
@ -1,17 +1,17 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.app.workflow import workflow_model
|
||||
from controllers.console.app.workflow import workflow_model, workflow_pagination_model
|
||||
from controllers.console.app.workflow_run import (
|
||||
workflow_run_detail_model,
|
||||
workflow_run_node_execution_list_model,
|
||||
@ -25,6 +25,7 @@ from controllers.console.snippets.payloads import (
|
||||
SnippetDraftSyncPayload,
|
||||
SnippetIterationNodeRunPayload,
|
||||
SnippetLoopNodeRunPayload,
|
||||
SnippetWorkflowListQuery,
|
||||
WorkflowRunQuery,
|
||||
)
|
||||
from controllers.console.wraps import (
|
||||
@ -36,6 +37,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
@ -46,6 +48,9 @@ from services.snippet_service import SnippetService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
# Register Pydantic models with Swagger
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
@ -54,6 +59,7 @@ register_schema_models(
|
||||
SnippetDraftRunPayload,
|
||||
SnippetIterationNodeRunPayload,
|
||||
SnippetLoopNodeRunPayload,
|
||||
SnippetWorkflowListQuery,
|
||||
WorkflowRunQuery,
|
||||
PublishWorkflowPayload,
|
||||
)
|
||||
@ -70,7 +76,7 @@ class SnippetNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def get_snippet[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
def get_snippet(view_func: Callable[P, R]):
|
||||
"""Decorator to fetch and validate snippet access."""
|
||||
|
||||
@wraps(view_func)
|
||||
@ -246,6 +252,40 @@ class SnippetDefaultBlockConfigsApi(Resource):
|
||||
return snippet_service.get_default_block_configs()
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows")
|
||||
class SnippetPublishedAllWorkflowApi(Resource):
|
||||
@console_ns.expect(console_ns.models[SnippetWorkflowListQuery.__name__])
|
||||
@console_ns.doc("get_all_snippet_published_workflows")
|
||||
@console_ns.doc(description="Get all published workflows for a snippet")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID"})
|
||||
@console_ns.response(200, "Published workflows retrieved successfully", workflow_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get all published workflow versions for snippet."""
|
||||
args = SnippetWorkflowListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
snippet_service = SnippetService()
|
||||
with Session(db.engine) as session:
|
||||
workflows, has_more = snippet_service.get_all_published_workflows(
|
||||
session=session,
|
||||
snippet=snippet,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
)
|
||||
serialized_workflows = marshal(workflows, workflow_model)
|
||||
|
||||
return {
|
||||
"items": serialized_workflows,
|
||||
"page": args.page,
|
||||
"limit": args.limit,
|
||||
"has_more": has_more,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs")
|
||||
class SnippetWorkflowRunsApi(Resource):
|
||||
@console_ns.doc("list_snippet_workflow_runs")
|
||||
|
||||
@ -12,11 +12,10 @@ Other routes mirror `workflow_draft_variable` app APIs under `/snippets/...`.
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, marshal, marshal_with
|
||||
from graphon.variables.types import SegmentType
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
@ -38,12 +37,16 @@ from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTE
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from graphon.variables.types import SegmentType
|
||||
from libs.login import current_user, login_required
|
||||
from models.snippet import CustomizedSnippet
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.snippet_service import SnippetService
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
_SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS: frozenset[str] = frozenset(
|
||||
{SYSTEM_VARIABLE_NODE_ID, CONVERSATION_VARIABLE_NODE_ID}
|
||||
)
|
||||
@ -59,7 +62,7 @@ def _ensure_snippet_draft_variable_row_allowed(
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
|
||||
def _snippet_draft_var_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R]:
|
||||
def _snippet_draft_var_prerequisite(f: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Setup, auth, snippet resolution, and tenant edit permission (same stack as snippet workflow APIs)."""
|
||||
|
||||
@setup_required
|
||||
|
||||
@ -6,7 +6,6 @@ from typing import Any, Literal
|
||||
import pytz
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from graphon.file import helpers as file_helpers
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from sqlalchemy import select
|
||||
|
||||
@ -40,6 +39,7 @@ from controllers.console.wraps import (
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from fields.member_fields import Account as AccountResponse
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, extract_remote_ip, timezone
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
|
||||
@ -6,9 +6,6 @@ from typing import Literal
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
@ -38,6 +35,9 @@ from extensions.ext_redis import redis_client
|
||||
from fields.base import ResponseModel
|
||||
from fields.end_user_fields import SimpleEndUser
|
||||
from fields.member_fields import SimpleAccount
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.workflow import WorkflowRun
|
||||
|
||||
@ -1,10 +1,9 @@
|
||||
import re
|
||||
from typing import Any, cast
|
||||
|
||||
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from models.model import AppModelConfigDict
|
||||
|
||||
_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
|
||||
|
||||
@ -9,12 +9,6 @@ from datetime import datetime
|
||||
from threading import Thread
|
||||
from typing import Any, Union
|
||||
|
||||
from graphon.entities.pause_reason import HumanInputRequired
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from graphon.nodes import BuiltinNodeTypes
|
||||
from graphon.runtime import GraphRuntimeState
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
@ -77,6 +71,12 @@ from core.repositories.human_input_repository import HumanInputFormRepositoryImp
|
||||
from core.workflow.file_reference import resolve_file_record_id
|
||||
from core.workflow.system_variables import build_system_variables
|
||||
from extensions.ext_database import db
|
||||
from graphon.entities.pause_reason import HumanInputRequired
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from graphon.nodes import BuiltinNodeTypes
|
||||
from graphon.runtime import GraphRuntimeState
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Conversation, EndUser, Message, MessageFile
|
||||
from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus
|
||||
|
||||
@ -2,9 +2,6 @@ from collections.abc import Generator, Mapping, Sequence
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from typing import TYPE_CHECKING, Any, Union, final
|
||||
|
||||
from graphon.enums import NodeType
|
||||
from graphon.file import File, FileUploadConfig
|
||||
from graphon.variables.input_entities import VariableEntityType
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.draft_variable_saver import (
|
||||
@ -16,6 +13,9 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from graphon.enums import NodeType
|
||||
from graphon.file import File, FileUploadConfig
|
||||
from graphon.variables.input_entities import VariableEntityType
|
||||
from libs.orjson import orjson_dumps
|
||||
from models import Account, EndUser
|
||||
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
|
||||
|
||||
@ -9,11 +9,10 @@ scope updates that matter to chat applications.
|
||||
|
||||
import logging
|
||||
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
from graphon.graph_events import GraphEngineEvent, NodeRunVariableUpdatedEvent
|
||||
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
from graphon.graph_events import GraphEngineEvent, NodeRunVariableUpdatedEvent
|
||||
from services.conversation_variable_updater import ConversationVariableUpdater
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -4,13 +4,6 @@ from collections.abc import Generator
|
||||
from threading import Thread
|
||||
from typing import Any, cast
|
||||
|
||||
from graphon.file import FileTransferMethod
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
@ -60,6 +53,13 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from graphon.file import FileTransferMethod
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile
|
||||
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from core.tools.signature import sign_tool_file
|
||||
from graphon.file import FileTransferMethod
|
||||
from graphon.file import helpers as file_helpers
|
||||
|
||||
from core.tools.signature import sign_tool_file
|
||||
from models.model import MessageFile, UploadFile
|
||||
|
||||
MAX_TOOL_FILE_EXTENSION_LENGTH = 10
|
||||
|
||||
@ -9,10 +9,6 @@ import urllib.parse
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
from graphon.file import FileTransferMethod
|
||||
from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
|
||||
from graphon.file.runtime import set_workflow_file_runtime
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
|
||||
from core.db.session_factory import session_factory
|
||||
@ -20,6 +16,9 @@ from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.workflow.file_reference import parse_file_reference
|
||||
from extensions.ext_storage import storage
|
||||
from graphon.file import FileTransferMethod
|
||||
from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
|
||||
from graphon.file.runtime import set_workflow_file_runtime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphon.file import File
|
||||
|
||||
@ -12,10 +12,6 @@ from contextvars import Token
|
||||
from dataclasses import dataclass
|
||||
from typing import cast, final, override
|
||||
|
||||
from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
from graphon.graph_events import GraphNodeEventBase
|
||||
from graphon.nodes.base.node import Node
|
||||
from opentelemetry import context as context_api
|
||||
from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context
|
||||
|
||||
@ -28,6 +24,10 @@ from extensions.otel.parser import (
|
||||
ToolNodeOTelParser,
|
||||
)
|
||||
from extensions.otel.runtime import is_instrument_flag_enabled
|
||||
from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
from graphon.graph_events import GraphNodeEventBase
|
||||
from graphon.nodes.base.node import Node
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -3,8 +3,6 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from graphon.node_events.base import NodeRunResult
|
||||
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
CustomizedMetrics,
|
||||
EvaluationCategory,
|
||||
@ -13,6 +11,7 @@ from core.evaluation.entities.evaluation_entity import (
|
||||
EvaluationMetric,
|
||||
NodeInfo,
|
||||
)
|
||||
from graphon.node_events.base import NodeRunResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -26,9 +26,10 @@ Typical usage::
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from graphon.utils.condition.entities import SupportedComparisonOperator
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphon.utils.condition.entities import SupportedComparisonOperator
|
||||
|
||||
|
||||
class JudgmentCondition(BaseModel):
|
||||
"""A single judgment condition that checks one metric value.
|
||||
|
||||
@ -14,15 +14,14 @@ import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from graphon.utils.condition.entities import SupportedComparisonOperator
|
||||
from graphon.utils.condition.processor import _evaluate_condition # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
from core.evaluation.entities.judgment_entity import (
|
||||
JudgmentCondition,
|
||||
JudgmentConditionResult,
|
||||
JudgmentConfig,
|
||||
JudgmentResult,
|
||||
)
|
||||
from graphon.utils.condition.entities import SupportedComparisonOperator
|
||||
from graphon.utils.condition.processor import _evaluate_condition # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -2,8 +2,6 @@ import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
DefaultMetric,
|
||||
@ -11,6 +9,7 @@ from core.evaluation.entities.evaluation_entity import (
|
||||
EvaluationItemResult,
|
||||
)
|
||||
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -11,13 +11,12 @@ persisting to the database) is handled by the evaluation task, not the runner.
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
DefaultMetric,
|
||||
EvaluationItemResult,
|
||||
)
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -2,8 +2,6 @@ import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
DefaultMetric,
|
||||
@ -11,6 +9,7 @@ from core.evaluation.entities.evaluation_entity import (
|
||||
EvaluationItemResult,
|
||||
)
|
||||
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
DefaultMetric,
|
||||
@ -10,6 +8,7 @@ from core.evaluation.entities.evaluation_entity import (
|
||||
EvaluationItemResult,
|
||||
)
|
||||
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -8,8 +8,6 @@ import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
DefaultMetric,
|
||||
@ -17,6 +15,7 @@ from core.evaluation.entities.evaluation_entity import (
|
||||
EvaluationItemResult,
|
||||
)
|
||||
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -2,8 +2,6 @@ import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
|
||||
from core.evaluation.entities.evaluation_entity import (
|
||||
DefaultMetric,
|
||||
@ -11,6 +9,7 @@ from core.evaluation.entities.evaluation_entity import (
|
||||
EvaluationItemResult,
|
||||
)
|
||||
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -9,7 +9,6 @@ from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import Flask, current_app
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from sqlalchemy import delete, func, select, update
|
||||
from sqlalchemy.orm.exc import ObjectDeletedError
|
||||
|
||||
@ -35,6 +34,7 @@ from core.tools.utils.web_reader_tool import get_image_upload_file_ids
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
|
||||
@ -5,11 +5,6 @@ from collections.abc import Sequence
|
||||
from typing import Any, Protocol, TypedDict, cast
|
||||
|
||||
import json_repair
|
||||
from graphon.enums import WorkflowNodeExecutionMetadataKey
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
@ -35,6 +30,11 @@ from core.ops.utils import measure_time
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from graphon.enums import WorkflowNodeExecutionMetadataKey
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from models import App, Message, WorkflowNodeExecutionModel
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
|
||||
class OpenAIModeration(Moderation):
|
||||
|
||||
@ -3,7 +3,6 @@ import os
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from langfuse import Langfuse
|
||||
from langfuse.api import (
|
||||
CreateGenerationBody,
|
||||
@ -40,6 +39,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
|
||||
from core.ops.utils import filter_none_values
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import MessageStatus
|
||||
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
from typing import Any
|
||||
|
||||
from graphon.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from pydantic import BaseModel, Field, computed_field, model_validator
|
||||
|
||||
from core.plugin.entities.endpoint import EndpointProviderDeclaration
|
||||
from core.plugin.entities.plugin import PluginResourceRequirements
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderEntity
|
||||
from graphon.model_runtime.entities.provider_entities import ProviderEntity
|
||||
|
||||
|
||||
class MarketplacePluginDeclaration(BaseModel):
|
||||
|
||||
@ -6,13 +6,6 @@ from collections.abc import Generator, Iterable, Sequence
|
||||
from threading import Lock
|
||||
from typing import IO, Any, Union
|
||||
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
|
||||
from graphon.model_runtime.runtime import ModelRuntime
|
||||
from pydantic import ValidationError
|
||||
from redis import RedisError
|
||||
|
||||
@ -21,6 +14,13 @@ from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.plugin.impl.asset import PluginAssetManager
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
|
||||
from graphon.model_runtime.runtime import ModelRuntime
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -2,9 +2,8 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.model_manager import ModelManager
|
||||
|
||||
@ -1,8 +1,5 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.rag.data_post_processor.reorder import ReorderRunner
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
@ -11,6 +8,8 @@ from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights
|
||||
from core.rag.rerank.rerank_base import BaseRerankRunner
|
||||
from core.rag.rerank.rerank_factory import RerankRunnerFactory
|
||||
from core.rag.rerank.rerank_type import RerankMode
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
|
||||
|
||||
class RerankingModelDict(TypedDict):
|
||||
|
||||
@ -1,10 +1,6 @@
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Any, Union
|
||||
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.app.llm import deduct_llm_quota
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
@ -12,6 +8,9 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.rag.retrieval.output_parser.react_output import ReactAction
|
||||
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
|
||||
|
||||
|
||||
@ -4,8 +4,6 @@ from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any, Protocol
|
||||
|
||||
from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData
|
||||
from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
@ -19,6 +17,8 @@ from core.workflow.human_input_compat import (
|
||||
InteractiveSurfaceDeliveryMethod,
|
||||
is_human_input_webapp_enabled,
|
||||
)
|
||||
from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData
|
||||
from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.uuid_utils import uuidv7
|
||||
from models.account import Account, TenantAccountJoin
|
||||
|
||||
@ -38,6 +38,17 @@ class ToolCredentialPolicyViolationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ApiToolProviderNotFoundError(ValueError):
|
||||
error_code = "api_tool_provider_not_found"
|
||||
provider_name: str
|
||||
tenant_id: str
|
||||
|
||||
def __init__(self, provider_name: str, tenant_id: str):
|
||||
self.provider_name = provider_name
|
||||
self.tenant_id = tenant_id
|
||||
super().__init__(f"api provider {provider_name} does not exist")
|
||||
|
||||
|
||||
class WorkflowToolHumanInputNotSupportedError(BaseHTTPException):
|
||||
error_code = "workflow_tool_human_input_not_supported"
|
||||
description = "Workflow with Human Input nodes cannot be published as a workflow tool."
|
||||
|
||||
@ -14,12 +14,13 @@ from typing import Annotated, Any, ClassVar, Literal
|
||||
|
||||
import bleach
|
||||
import markdown
|
||||
from markdown.extensions.tables import TableExtension
|
||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter
|
||||
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from graphon.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from graphon.runtime import VariablePool
|
||||
from graphon.variables.consts import SELECTORS_LENGTH
|
||||
from markdown.extensions.tables import TableExtension
|
||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter
|
||||
|
||||
|
||||
class DeliveryMethodType(enum.StrEnum):
|
||||
|
||||
@ -5,22 +5,6 @@ from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, cast, final, override
|
||||
|
||||
from graphon.entities.base_node_data import BaseNodeData
|
||||
from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
from graphon.file.file_manager import file_manager
|
||||
from graphon.graph.graph import NodeFactory
|
||||
from graphon.model_runtime.memory import PromptMessageMemory
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.nodes.base.node import Node
|
||||
from graphon.nodes.code.code_node import WorkflowCodeExecutor
|
||||
from graphon.nodes.code.entities import CodeLanguage
|
||||
from graphon.nodes.code.limits import CodeNodeLimits
|
||||
from graphon.nodes.document_extractor import UnstructuredApiConfig
|
||||
from graphon.nodes.http_request import build_http_request_config
|
||||
from graphon.nodes.llm.entities import LLMNodeData
|
||||
from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
|
||||
from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -56,6 +40,22 @@ from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text, system_variable_selector
|
||||
from core.workflow.template_rendering import CodeExecutorJinja2TemplateRenderer
|
||||
from extensions.ext_database import db
|
||||
from graphon.entities.base_node_data import BaseNodeData
|
||||
from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
from graphon.file.file_manager import file_manager
|
||||
from graphon.graph.graph import NodeFactory
|
||||
from graphon.model_runtime.memory import PromptMessageMemory
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.nodes.base.node import Node
|
||||
from graphon.nodes.code.code_node import WorkflowCodeExecutor
|
||||
from graphon.nodes.code.entities import CodeLanguage
|
||||
from graphon.nodes.code.limits import CodeNodeLimits
|
||||
from graphon.nodes.document_extractor import UnstructuredApiConfig
|
||||
from graphon.nodes.http_request import build_http_request_config
|
||||
from graphon.nodes.llm.entities import LLMNodeData
|
||||
from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
|
||||
from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||
from models.model import Conversation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@ -4,6 +4,32 @@ from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.model_manager import ModelInstance
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.repositories.human_input_repository import (
|
||||
FormCreateParams,
|
||||
HumanInputFormRepository,
|
||||
HumanInputFormRepositoryImpl,
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.workflow.file_reference import build_file_reference
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from graphon.file import FileTransferMethod, FileType
|
||||
from graphon.model_runtime.entities import LLMMode
|
||||
from graphon.model_runtime.entities.llm_entities import (
|
||||
@ -34,32 +60,6 @@ from graphon.nodes.tool_runtime_entities import (
|
||||
ToolRuntimeMessage,
|
||||
ToolRuntimeParameter,
|
||||
)
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.model_manager import ModelInstance
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.repositories.human_input_repository import (
|
||||
FormCreateParams,
|
||||
HumanInputFormRepository,
|
||||
HumanInputFormRepositoryImpl,
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.workflow.file_reference import build_file_reference
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.dataset import SegmentAttachmentBinding
|
||||
from models.model import UploadFile
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
@ -76,13 +76,12 @@ from .human_input_compat import (
|
||||
from .system_variables import SystemVariableKey, get_system_text
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage
|
||||
from graphon.file import File
|
||||
from graphon.nodes.llm.file_saver import LLMFileSaver
|
||||
from graphon.nodes.tool.entities import ToolNodeData
|
||||
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage
|
||||
|
||||
|
||||
_file_access_controller = DatabaseFileAccessController()
|
||||
|
||||
|
||||
@ -3,15 +3,14 @@ from __future__ import annotations
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
|
||||
from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
|
||||
from graphon.nodes.base.node import Node
|
||||
from graphon.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
|
||||
from .entities import AgentNodeData
|
||||
from .exceptions import (
|
||||
AgentInvocationError,
|
||||
|
||||
@ -3,6 +3,14 @@ from __future__ import annotations
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
@ -15,14 +23,6 @@ from graphon.node_events import (
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from graphon.variables.segments import ArrayFileSegment
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import ToolFile
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
|
||||
@ -4,8 +4,6 @@ import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from graphon.runtime import VariablePool
|
||||
from packaging.version import Version
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
@ -21,6 +19,8 @@ from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolP
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from graphon.runtime import VariablePool
|
||||
from models.model import Conversation
|
||||
|
||||
from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
from graphon.model_runtime.entities import LLMUsage
|
||||
from graphon.nodes.llm.entities import ModelConfig
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
||||
from graphon.model_runtime.entities import LLMUsage
|
||||
from graphon.nodes.llm.entities import ModelConfig
|
||||
|
||||
from .entities import MetadataFilteringCondition
|
||||
|
||||
|
||||
@ -3,11 +3,10 @@ from __future__ import annotations
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
|
||||
from graphon.nodes.code.entities import CodeLanguage
|
||||
from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
|
||||
|
||||
|
||||
class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
|
||||
"""Sandbox-backed Jinja2 renderer for workflow-owned node composition."""
|
||||
|
||||
@ -3,20 +3,6 @@ import time
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from graphon.entities import GraphInitParams
|
||||
from graphon.entities.graph_config import NodeConfigDictAdapter
|
||||
from graphon.errors import WorkflowNodeRunFailedError
|
||||
from graphon.file import File
|
||||
from graphon.graph import Graph
|
||||
from graphon.graph_engine import GraphEngine, GraphEngineConfig
|
||||
from graphon.graph_engine.command_channels import CommandChannel, InMemoryChannel
|
||||
from graphon.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer
|
||||
from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
|
||||
from graphon.nodes import BuiltinNodeTypes
|
||||
from graphon.nodes.base.node import Node
|
||||
from graphon.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool
|
||||
from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
|
||||
from configs import dify_config
|
||||
from context import capture_current_context
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
@ -40,6 +26,19 @@ from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add
|
||||
from core.workflow.variable_prefixes import ENVIRONMENT_VARIABLE_NODE_ID
|
||||
from extensions.otel.runtime import is_instrument_flag_enabled
|
||||
from factories import file_factory
|
||||
from graphon.entities import GraphInitParams
|
||||
from graphon.entities.graph_config import NodeConfigDictAdapter
|
||||
from graphon.errors import WorkflowNodeRunFailedError
|
||||
from graphon.file import File
|
||||
from graphon.graph import Graph
|
||||
from graphon.graph_engine import GraphEngine, GraphEngineConfig
|
||||
from graphon.graph_engine.command_channels import CommandChannel, InMemoryChannel
|
||||
from graphon.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer
|
||||
from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
|
||||
from graphon.nodes import BuiltinNodeTypes
|
||||
from graphon.nodes.base.node import Node
|
||||
from graphon.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool
|
||||
from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -3,10 +3,9 @@ from __future__ import annotations
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from graphon.enums import WorkflowNodeExecutionMetadataKey
|
||||
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
from graphon.enums import WorkflowNodeExecutionMetadataKey
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from core.db.session_factory import session_factory
|
||||
from events.app_event import app_was_created
|
||||
from extensions.ext_database import db
|
||||
from models.model import InstalledApp
|
||||
|
||||
|
||||
@ -12,5 +12,6 @@ def handle(sender, **kwargs):
|
||||
app_id=app.id,
|
||||
app_owner_tenant_id=app.tenant_id,
|
||||
)
|
||||
db.session.add(installed_app)
|
||||
db.session.commit()
|
||||
with session_factory.create_session() as session:
|
||||
session.add(installed_app)
|
||||
session.commit()
|
||||
|
||||
@ -13,10 +13,6 @@ from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from graphon.entities import WorkflowNodeExecution
|
||||
from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
@ -26,6 +22,10 @@ from core.repositories.factory import OrderConfig, WorkflowNodeExecutionReposito
|
||||
from extensions.logstore.aliyun_logstore import AliyunLogStore
|
||||
from extensions.logstore.repositories import safe_float, safe_int
|
||||
from extensions.logstore.sql_escape import escape_identifier
|
||||
from graphon.entities import WorkflowNodeExecution
|
||||
from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from libs.helper import extract_tenant_id
|
||||
from models import (
|
||||
Account,
|
||||
|
||||
@ -7,12 +7,12 @@ import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.file_access import FileAccessControllerProtocol
|
||||
from core.workflow.file_reference import build_file_reference
|
||||
from extensions.ext_database import db
|
||||
from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type
|
||||
from models import ToolFile, UploadFile
|
||||
|
||||
from .common import resolve_mapping_file_id
|
||||
|
||||
@ -4,9 +4,8 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from graphon.file import File, FileBelongsTo, FileTransferMethod, FileUploadConfig
|
||||
|
||||
from core.app.file_access import FileAccessControllerProtocol
|
||||
from graphon.file import File, FileBelongsTo, FileTransferMethod, FileUploadConfig
|
||||
from models import MessageFile
|
||||
|
||||
from .builders import build_from_mapping
|
||||
|
||||
@ -5,12 +5,12 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from graphon.file import File, FileTransferMethod
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.file_access import FileAccessControllerProtocol
|
||||
from core.workflow.file_reference import build_file_reference, parse_file_reference
|
||||
from graphon.file import File, FileTransferMethod
|
||||
from models import ToolFile, UploadFile
|
||||
|
||||
|
||||
|
||||
@ -3,10 +3,10 @@ from __future__ import annotations
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from graphon.file import File
|
||||
from pydantic import Field, field_validator, model_validator
|
||||
|
||||
from fields.base import ResponseModel
|
||||
from graphon.file import File
|
||||
|
||||
type JSONValue = Any
|
||||
|
||||
|
||||
@ -4,10 +4,10 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import Namespace, fields
|
||||
from graphon.variables.types import SegmentType
|
||||
from pydantic import field_validator
|
||||
|
||||
from fields.base import ResponseModel
|
||||
from graphon.variables.types import SegmentType
|
||||
from libs.helper import TimestampField
|
||||
|
||||
from ._value_type_serializer import serialize_value_type
|
||||
|
||||
@ -85,7 +85,7 @@ class EvaluationConfiguration(Base):
|
||||
"""Return judgment config (stored in the judgement_conditions column)."""
|
||||
if self.judgement_conditions:
|
||||
parsed = json.loads(self.judgement_conditions)
|
||||
return parsed or None
|
||||
return parsed if parsed else None
|
||||
return None
|
||||
|
||||
@property
|
||||
|
||||
@ -4,9 +4,8 @@ from collections.abc import Callable, Mapping
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from graphon.file import File, FileTransferMethod
|
||||
|
||||
from core.workflow.file_reference import parse_file_reference
|
||||
from graphon.file import File, FileTransferMethod
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
|
||||
@ -4,8 +4,6 @@ from typing import Any, TypedDict, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask_sqlalchemy.pagination import Pagination
|
||||
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
@ -17,6 +15,8 @@ from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_was_created, app_was_deleted, app_was_updated
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
|
||||
@ -10,9 +10,6 @@ from collections.abc import Sequence
|
||||
from typing import Any, Literal, TypedDict, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from graphon.file import helpers as file_helpers
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from redis.exceptions import LockNotOwnedError
|
||||
from sqlalchemy import delete, exists, func, select, update
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
@ -31,6 +28,9 @@ from events.dataset_event import dataset_was_deleted
|
||||
from events.document_event import document_was_deleted
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.file import helpers as file_helpers
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
|
||||
@ -4,8 +4,6 @@ import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Union
|
||||
|
||||
from graphon.enums import WorkflowNodeExecutionMetadataKey
|
||||
from graphon.node_events.base import NodeRunResult
|
||||
from openpyxl import Workbook, load_workbook
|
||||
from openpyxl.styles import Alignment, Border, Font, PatternFill, Side
|
||||
from openpyxl.utils import get_column_letter
|
||||
@ -25,6 +23,8 @@ from core.evaluation.entities.evaluation_entity import (
|
||||
NodeInfo,
|
||||
)
|
||||
from core.evaluation.evaluation_manager import EvaluationManager
|
||||
from graphon.enums import WorkflowNodeExecutionMetadataKey
|
||||
from graphon.node_events.base import NodeRunResult
|
||||
from models.evaluation import (
|
||||
EvaluationConfiguration,
|
||||
EvaluationRun,
|
||||
@ -813,9 +813,9 @@ class EvaluationService:
|
||||
workflow_run_id: str,
|
||||
) -> dict[str, NodeRunResult]:
|
||||
"""Query all node execution records for a workflow run."""
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from sqlalchemy import asc, select
|
||||
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
stmt = (
|
||||
|
||||
@ -4,13 +4,13 @@ from typing import Any, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from graphon.nodes.http_request.exc import InvalidHttpMethodError
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.helper import ssrf_proxy
|
||||
from core.rag.entities import MetadataFilteringCondition
|
||||
from extensions.ext_database import db
|
||||
from graphon.nodes.http_request.exc import InvalidHttpMethodError
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import (
|
||||
Dataset,
|
||||
|
||||
@ -3,8 +3,6 @@ import logging
|
||||
import time
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from graphon.model_runtime.entities import LLMMode
|
||||
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
@ -12,6 +10,7 @@ from core.rag.models.document import Document
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities import LLMMode
|
||||
from models import Account
|
||||
from models.dataset import Dataset, DatasetQuery
|
||||
from models.enums import CreatorUserRole, DatasetQuerySource
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
@ -14,6 +13,7 @@ from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.utils import measure_time
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import Account
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule
|
||||
|
||||
from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, create_plugin_provider_manager
|
||||
from core.provider_manager import ProviderManager
|
||||
from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule
|
||||
from models.provider import ProviderType
|
||||
from services.entities.model_provider_entities import (
|
||||
CustomConfigurationResponse,
|
||||
|
||||
@ -7,8 +7,6 @@ from enum import StrEnum
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import yaml # type: ignore
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from packaging import version
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
@ -17,6 +15,8 @@ from sqlalchemy.orm import Session
|
||||
from core.helper import ssrf_proxy
|
||||
from core.plugin.entities.plugin import PluginDependency
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from models import Account
|
||||
from models.snippet import CustomizedSnippet, SnippetType
|
||||
from models.workflow import Workflow
|
||||
|
||||
@ -23,13 +23,13 @@ import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Union
|
||||
|
||||
from graphon.file.models import File
|
||||
from sqlalchemy.orm import make_transient
|
||||
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from factories import file_factory
|
||||
from graphon.file.models import File
|
||||
from models import Account
|
||||
from models.model import AppMode, EndUser
|
||||
from models.snippet import CustomizedSnippet
|
||||
|
||||
@ -4,12 +4,12 @@ from collections.abc import Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.workflow.node_factory import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from extensions.ext_database import db
|
||||
from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import Account
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
|
||||
@ -6,8 +6,6 @@ import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import TypedDict, cast
|
||||
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -18,6 +16,8 @@ from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
from core.rag.models.document import Document
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs import helper
|
||||
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
|
||||
from models.dataset import Document as DatasetDocument
|
||||
@ -349,7 +349,6 @@ class SummaryIndexService:
|
||||
summary_record_id,
|
||||
)
|
||||
summary_record_in_session = DocumentSegmentSummary(
|
||||
id=summary_record_id, # Use the same ID if available
|
||||
dataset_id=dataset.id,
|
||||
document_id=segment.document_id,
|
||||
chunk_id=segment.id,
|
||||
@ -360,6 +359,9 @@ class SummaryIndexService:
|
||||
status=SummaryStatus.COMPLETED,
|
||||
enabled=True,
|
||||
)
|
||||
if summary_record_in_session is None:
|
||||
raise RuntimeError("summary_record_in_session should not be None at this point")
|
||||
summary_record_in_session.id = summary_record_id
|
||||
session.add(summary_record_in_session)
|
||||
logger.info(
|
||||
"Created new summary record (id=%s) for segment %s after vectorization",
|
||||
|
||||
@ -3,7 +3,6 @@ import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from sqlalchemy import delete, or_, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
@ -15,6 +14,7 @@ from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurati
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from models.model import App
|
||||
from models.tools import WorkflowToolProvider
|
||||
from models.workflow import Workflow
|
||||
|
||||
@ -2,7 +2,6 @@ import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -14,6 +13,7 @@ from core.workflow.nodes.trigger_schedule.entities import (
|
||||
VisualConfig,
|
||||
)
|
||||
from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h
|
||||
from models.account import Account, TenantAccountJoin
|
||||
from models.trigger import WorkflowSchedulePlan
|
||||
|
||||
@ -7,9 +7,6 @@ from typing import Any, NotRequired, TypedDict
|
||||
|
||||
import orjson
|
||||
from flask import request
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from graphon.file import FileTransferMethod
|
||||
from graphon.variables.types import ArrayValidation, SegmentType
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
@ -31,6 +28,9 @@ from enums.quota_type import QuotaType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from factories import file_factory
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from graphon.file import FileTransferMethod
|
||||
from graphon.variables.types import ArrayValidation, SegmentType
|
||||
from models.enums import AppTriggerStatus, AppTriggerType
|
||||
from models.model import App
|
||||
from models.trigger import AppTrigger, WorkflowWebhookTrigger
|
||||
|
||||
@ -3,10 +3,10 @@ import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from sqlalchemy import and_, func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
|
||||
from models.enums import AppTriggerType, CreatorUserRole
|
||||
from models.trigger import WorkflowTriggerLog
|
||||
|
||||
@ -10,7 +10,6 @@ from datetime import UTC, datetime
|
||||
from typing import Any, NotRequired
|
||||
|
||||
from celery import shared_task
|
||||
from graphon.runtime import GraphRuntimeState
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from typing_extensions import TypedDict
|
||||
@ -24,6 +23,7 @@ from core.app.layers.trigger_post_layer import TriggerPostLayer
|
||||
from core.db.session_factory import session_factory
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from graphon.runtime import GraphRuntimeState
|
||||
from models.account import Account
|
||||
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus
|
||||
from models.model import App, EndUser, Tenant
|
||||
|
||||
@ -4,7 +4,6 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
from graphon.node_events import NodeRunResult
|
||||
from openpyxl import Workbook
|
||||
from openpyxl.styles import Alignment, Border, Font, PatternFill, Side
|
||||
from openpyxl.utils import get_column_letter
|
||||
@ -28,6 +27,7 @@ from core.evaluation.runners.retrieval_evaluation_runner import RetrievalEvaluat
|
||||
from core.evaluation.runners.snippet_evaluation_runner import SnippetEvaluationRunner
|
||||
from core.evaluation.runners.workflow_evaluation_runner import WorkflowEvaluationRunner
|
||||
from extensions.ext_database import db
|
||||
from graphon.node_events import NodeRunResult
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import CreatorUserRole
|
||||
from models.evaluation import EvaluationRun, EvaluationRunItem, EvaluationRunStatus
|
||||
|
||||
@ -2,18 +2,17 @@ import time
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.graph import Graph
|
||||
from graphon.node_events import StreamCompletedEvent
|
||||
from graphon.nodes.protocols import ToolFileManagerProtocol
|
||||
from graphon.nodes.tool.tool_node import ToolNode
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.node_runtime import DifyToolNodeRuntime
|
||||
from core.workflow.system_variables import build_system_variables
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.graph import Graph
|
||||
from graphon.node_events import StreamCompletedEvent
|
||||
from graphon.nodes.protocols import ToolFileManagerProtocol
|
||||
from graphon.nodes.tool.tool_node import ToolNode
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
|
||||
|
||||
|
||||
@ -3,12 +3,12 @@
|
||||
import uuid
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
from graphon.variables.segments import StringSegment
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
|
||||
from factories.variable_factory import segment_to_variable
|
||||
from graphon.variables.segments import StringSegment
|
||||
from models import Workflow
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
|
||||
@ -6,7 +6,6 @@ from decimal import Decimal
|
||||
from uuid import uuid4
|
||||
|
||||
from graphon.nodes.human_input.entities import FormDefinition, UserAction
|
||||
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.account import Account, Tenant, TenantAccountJoin
|
||||
from models.enums import ConversationFromSource, InvokeFrom
|
||||
|
||||
@ -10,10 +10,10 @@ from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.file_reference import build_file_reference
|
||||
from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod
|
||||
from models.model import App, AppMode, Conversation, Message
|
||||
|
||||
|
||||
|
||||
@ -9,9 +9,9 @@ from collections.abc import Generator
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from models.enums import ConversationFromSource, InvokeFrom
|
||||
from models.model import App, AppMode, Conversation, Message, Site
|
||||
from models.workflow import Workflow, WorkflowRun, WorkflowRunTriggeredFrom, WorkflowType
|
||||
|
||||
@ -4,13 +4,13 @@ from typing import Any, NamedTuple
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from sqlalchemy import exc as sa_exc
|
||||
from sqlalchemy import insert, select
|
||||
from sqlalchemy.engine import Connection, Engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
|
||||
from sqlalchemy.sql.sqltypes import VARCHAR
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from models.types import EnumText
|
||||
|
||||
_USER_TABLE = "enum_text_users"
|
||||
|
||||
@ -12,11 +12,11 @@ from decimal import Decimal
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from graphon.nodes.human_input.entities import FormDefinition, UserAction
|
||||
from graphon.nodes.human_input.enums import HumanInputFormStatus
|
||||
from sqlalchemy import Engine, delete, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from graphon.nodes.human_input.entities import FormDefinition, UserAction
|
||||
from graphon.nodes.human_input.enums import HumanInputFormStatus
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import ConversationFromSource, InvokeFrom
|
||||
|
||||
@ -7,6 +7,11 @@ from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories.factory import OrderConfig
|
||||
from graphon.entities import WorkflowNodeExecution
|
||||
from graphon.enums import (
|
||||
BuiltinNodeTypes,
|
||||
@ -14,11 +19,6 @@ from graphon.enums import (
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories.factory import OrderConfig
|
||||
from models.account import Account, Tenant
|
||||
from models.enums import CreatorUserRole
|
||||
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
@ -7,12 +7,12 @@ from datetime import timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from graphon.entities import WorkflowExecution
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from sqlalchemy import Engine, delete
|
||||
from sqlalchemy import exc as sa_exc
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from graphon.entities import WorkflowExecution
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
|
||||
from models.workflow import WorkflowRun, WorkflowType
|
||||
|
||||
@ -0,0 +1,524 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from graphon.variables import FloatVariable, IntegerVariable, StringVariable
|
||||
from models.account import Account, Tenant, TenantAccountJoin
|
||||
from models.enums import ConversationFromSource
|
||||
from models.model import App, Conversation, EndUser
|
||||
from models.workflow import ConversationVariable
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import (
|
||||
ConversationVariableNotExistsError,
|
||||
ConversationVariableTypeMismatchError,
|
||||
LastConversationNotExistsError,
|
||||
)
|
||||
|
||||
|
||||
class ConversationServiceVariableIntegrationFactory:
|
||||
@staticmethod
|
||||
def create_app_and_account(db_session_with_containers):
|
||||
tenant = Tenant(name=f"Tenant {uuid4()}")
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
account = Account(
|
||||
name=f"Account {uuid4()}",
|
||||
email=f"conversation-variable-{uuid4()}@example.com",
|
||||
password="hashed-password",
|
||||
password_salt="salt",
|
||||
interface_language="en-US",
|
||||
timezone="UTC",
|
||||
)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role="owner",
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(tenant_join)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name=f"App {uuid4()}",
|
||||
description="",
|
||||
mode="chat",
|
||||
icon_type="emoji",
|
||||
icon="bot",
|
||||
icon_background="#FFFFFF",
|
||||
enable_site=False,
|
||||
enable_api=True,
|
||||
api_rpm=100,
|
||||
api_rph=100,
|
||||
is_demo=False,
|
||||
is_public=False,
|
||||
is_universal=False,
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(app)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return app, account
|
||||
|
||||
@staticmethod
|
||||
def create_end_user(db_session_with_containers, app: App):
|
||||
end_user = EndUser(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
type=InvokeFrom.SERVICE_API.value,
|
||||
external_user_id=f"external-{uuid4()}",
|
||||
name=f"End User {uuid4()}",
|
||||
is_anonymous=False,
|
||||
session_id=f"session-{uuid4()}",
|
||||
)
|
||||
db_session_with_containers.add(end_user)
|
||||
db_session_with_containers.commit()
|
||||
return end_user
|
||||
|
||||
@staticmethod
|
||||
def create_conversation(
|
||||
db_session_with_containers,
|
||||
app: App,
|
||||
user: Account | EndUser,
|
||||
*,
|
||||
name: str | None = None,
|
||||
invoke_from: InvokeFrom = InvokeFrom.WEB_APP,
|
||||
created_at: datetime | None = None,
|
||||
updated_at: datetime | None = None,
|
||||
) -> Conversation:
|
||||
conversation = Conversation(
|
||||
app_id=app.id,
|
||||
app_model_config_id=None,
|
||||
model_provider=None,
|
||||
model_id="",
|
||||
override_model_configs=None,
|
||||
mode=app.mode,
|
||||
name=name or f"Conversation {uuid4()}",
|
||||
summary="",
|
||||
inputs={},
|
||||
introduction="",
|
||||
system_instruction="",
|
||||
system_instruction_tokens=0,
|
||||
status="normal",
|
||||
invoke_from=invoke_from.value,
|
||||
from_source=ConversationFromSource.API if isinstance(user, EndUser) else ConversationFromSource.CONSOLE,
|
||||
from_end_user_id=user.id if isinstance(user, EndUser) else None,
|
||||
from_account_id=user.id if isinstance(user, Account) else None,
|
||||
dialogue_count=0,
|
||||
is_deleted=False,
|
||||
)
|
||||
conversation.inputs = {}
|
||||
if created_at is not None:
|
||||
conversation.created_at = created_at
|
||||
if updated_at is not None:
|
||||
conversation.updated_at = updated_at
|
||||
|
||||
db_session_with_containers.add(conversation)
|
||||
db_session_with_containers.commit()
|
||||
return conversation
|
||||
|
||||
@staticmethod
|
||||
def create_variable(
|
||||
db_session_with_containers,
|
||||
*,
|
||||
app: App,
|
||||
conversation: Conversation,
|
||||
variable: StringVariable | FloatVariable | IntegerVariable,
|
||||
created_at: datetime | None = None,
|
||||
) -> ConversationVariable:
|
||||
row = ConversationVariable.from_variable(app_id=app.id, conversation_id=conversation.id, variable=variable)
|
||||
if created_at is not None:
|
||||
row.created_at = created_at
|
||||
row.updated_at = created_at
|
||||
|
||||
db_session_with_containers.add(row)
|
||||
db_session_with_containers.commit()
|
||||
return row
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_conversation_service_session_factory(flask_app_with_containers):
|
||||
del flask_app_with_containers
|
||||
real_session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
|
||||
with (
|
||||
patch("services.conversation_service.session_factory.create_session", side_effect=lambda: real_session_maker()),
|
||||
patch("services.conversation_service.session_factory.get_session_maker", return_value=real_session_maker),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class TestConversationServiceVariables:
|
||||
def test_get_conversational_variable_success(
|
||||
self, db_session_with_containers, real_conversation_service_session_factory
|
||||
):
|
||||
del real_conversation_service_session_factory
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
app, account = factory.create_app_and_account(db_session_with_containers)
|
||||
conversation = factory.create_conversation(db_session_with_containers, app, account)
|
||||
older_time = datetime(2024, 1, 1, 12, 0, 0)
|
||||
newer_time = older_time + timedelta(minutes=5)
|
||||
|
||||
first_variable = factory.create_variable(
|
||||
db_session_with_containers,
|
||||
app=app,
|
||||
conversation=conversation,
|
||||
variable=StringVariable(id=str(uuid4()), name="topic", value="billing"),
|
||||
created_at=older_time,
|
||||
)
|
||||
second_variable = factory.create_variable(
|
||||
db_session_with_containers,
|
||||
app=app,
|
||||
conversation=conversation,
|
||||
variable=StringVariable(id=str(uuid4()), name="priority", value="high"),
|
||||
created_at=newer_time,
|
||||
)
|
||||
|
||||
result = ConversationService.get_conversational_variable(
|
||||
app_model=app,
|
||||
conversation_id=conversation.id,
|
||||
user=account,
|
||||
limit=10,
|
||||
last_id=None,
|
||||
)
|
||||
|
||||
assert [item["id"] for item in result.data] == [first_variable.id, second_variable.id]
|
||||
assert [item["name"] for item in result.data] == ["topic", "priority"]
|
||||
assert result.limit == 10
|
||||
assert result.has_more is False
|
||||
|
||||
def test_get_conversational_variable_with_last_id(
|
||||
self, db_session_with_containers, real_conversation_service_session_factory
|
||||
):
|
||||
del real_conversation_service_session_factory
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
app, account = factory.create_app_and_account(db_session_with_containers)
|
||||
conversation = factory.create_conversation(db_session_with_containers, app, account)
|
||||
base_time = datetime(2024, 1, 1, 9, 0, 0)
|
||||
|
||||
first_variable = factory.create_variable(
|
||||
db_session_with_containers,
|
||||
app=app,
|
||||
conversation=conversation,
|
||||
variable=StringVariable(id=str(uuid4()), name="topic", value="billing"),
|
||||
created_at=base_time,
|
||||
)
|
||||
second_variable = factory.create_variable(
|
||||
db_session_with_containers,
|
||||
app=app,
|
||||
conversation=conversation,
|
||||
variable=StringVariable(id=str(uuid4()), name="priority", value="high"),
|
||||
created_at=base_time + timedelta(minutes=1),
|
||||
)
|
||||
third_variable = factory.create_variable(
|
||||
db_session_with_containers,
|
||||
app=app,
|
||||
conversation=conversation,
|
||||
variable=StringVariable(id=str(uuid4()), name="owner", value="alice"),
|
||||
created_at=base_time + timedelta(minutes=2),
|
||||
)
|
||||
|
||||
result = ConversationService.get_conversational_variable(
|
||||
app_model=app,
|
||||
conversation_id=conversation.id,
|
||||
user=account,
|
||||
limit=10,
|
||||
last_id=first_variable.id,
|
||||
)
|
||||
|
||||
assert [item["id"] for item in result.data] == [second_variable.id, third_variable.id]
|
||||
assert result.has_more is False
|
||||
|
||||
def test_get_conversational_variable_last_id_not_found_raises_error(
|
||||
self, db_session_with_containers, real_conversation_service_session_factory
|
||||
):
|
||||
del real_conversation_service_session_factory
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
app, account = factory.create_app_and_account(db_session_with_containers)
|
||||
conversation = factory.create_conversation(db_session_with_containers, app, account)
|
||||
|
||||
with pytest.raises(ConversationVariableNotExistsError):
|
||||
ConversationService.get_conversational_variable(
|
||||
app_model=app,
|
||||
conversation_id=conversation.id,
|
||||
user=account,
|
||||
limit=10,
|
||||
last_id=str(uuid4()),
|
||||
)
|
||||
|
||||
def test_get_conversational_variable_sets_has_more(
|
||||
self, db_session_with_containers, real_conversation_service_session_factory
|
||||
):
|
||||
del real_conversation_service_session_factory
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
app, account = factory.create_app_and_account(db_session_with_containers)
|
||||
conversation = factory.create_conversation(db_session_with_containers, app, account)
|
||||
|
||||
for index in range(3):
|
||||
factory.create_variable(
|
||||
db_session_with_containers,
|
||||
app=app,
|
||||
conversation=conversation,
|
||||
variable=StringVariable(id=str(uuid4()), name=f"var_{index}", value=f"value_{index}"),
|
||||
created_at=datetime(2024, 1, 1, 10, 0, index),
|
||||
)
|
||||
|
||||
result = ConversationService.get_conversational_variable(
|
||||
app_model=app,
|
||||
conversation_id=conversation.id,
|
||||
user=account,
|
||||
limit=2,
|
||||
last_id=None,
|
||||
)
|
||||
|
||||
assert len(result.data) == 2
|
||||
assert result.has_more is True
|
||||
|
||||
def test_update_conversation_variable_success(
|
||||
self, db_session_with_containers, real_conversation_service_session_factory
|
||||
):
|
||||
del real_conversation_service_session_factory
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
app, account = factory.create_app_and_account(db_session_with_containers)
|
||||
conversation = factory.create_conversation(db_session_with_containers, app, account)
|
||||
existing = factory.create_variable(
|
||||
db_session_with_containers,
|
||||
app=app,
|
||||
conversation=conversation,
|
||||
variable=StringVariable(id=str(uuid4()), name="topic", value="billing"),
|
||||
)
|
||||
updated_at = datetime(2024, 1, 1, 15, 0, 0)
|
||||
|
||||
with patch("services.conversation_service.naive_utc_now", return_value=updated_at):
|
||||
result = ConversationService.update_conversation_variable(
|
||||
app_model=app,
|
||||
conversation_id=conversation.id,
|
||||
variable_id=existing.id,
|
||||
user=account,
|
||||
new_value="support",
|
||||
)
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
persisted = db_session_with_containers.get(ConversationVariable, (existing.id, conversation.id))
|
||||
|
||||
assert persisted is not None
|
||||
assert persisted.to_variable().value == "support"
|
||||
assert result["id"] == existing.id
|
||||
assert result["value"] == "support"
|
||||
assert result["updated_at"] == updated_at
|
||||
|
||||
def test_update_conversation_variable_not_found_raises_error(
|
||||
self, db_session_with_containers, real_conversation_service_session_factory
|
||||
):
|
||||
del real_conversation_service_session_factory
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
app, account = factory.create_app_and_account(db_session_with_containers)
|
||||
conversation = factory.create_conversation(db_session_with_containers, app, account)
|
||||
|
||||
with pytest.raises(ConversationVariableNotExistsError):
|
||||
ConversationService.update_conversation_variable(
|
||||
app_model=app,
|
||||
conversation_id=conversation.id,
|
||||
variable_id=str(uuid4()),
|
||||
user=account,
|
||||
new_value="support",
|
||||
)
|
||||
|
||||
def test_update_conversation_variable_type_mismatch_raises_error(
|
||||
self, db_session_with_containers, real_conversation_service_session_factory
|
||||
):
|
||||
del real_conversation_service_session_factory
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
app, account = factory.create_app_and_account(db_session_with_containers)
|
||||
conversation = factory.create_conversation(db_session_with_containers, app, account)
|
||||
existing = factory.create_variable(
|
||||
db_session_with_containers,
|
||||
app=app,
|
||||
conversation=conversation,
|
||||
variable=FloatVariable(id=str(uuid4()), name="score", value=1.5),
|
||||
)
|
||||
|
||||
with pytest.raises(ConversationVariableTypeMismatchError, match="expects float"):
|
||||
ConversationService.update_conversation_variable(
|
||||
app_model=app,
|
||||
conversation_id=conversation.id,
|
||||
variable_id=existing.id,
|
||||
user=account,
|
||||
new_value="wrong-type",
|
||||
)
|
||||
|
||||
def test_update_conversation_variable_integer_number_compatibility(
|
||||
self, db_session_with_containers, real_conversation_service_session_factory
|
||||
):
|
||||
del real_conversation_service_session_factory
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
app, account = factory.create_app_and_account(db_session_with_containers)
|
||||
conversation = factory.create_conversation(db_session_with_containers, app, account)
|
||||
existing = factory.create_variable(
|
||||
db_session_with_containers,
|
||||
app=app,
|
||||
conversation=conversation,
|
||||
variable=IntegerVariable(id=str(uuid4()), name="attempts", value=1),
|
||||
)
|
||||
|
||||
result = ConversationService.update_conversation_variable(
|
||||
app_model=app,
|
||||
conversation_id=conversation.id,
|
||||
variable_id=existing.id,
|
||||
user=account,
|
||||
new_value=42,
|
||||
)
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
persisted = db_session_with_containers.get(ConversationVariable, (existing.id, conversation.id))
|
||||
|
||||
assert persisted is not None
|
||||
assert persisted.to_variable().value == 42
|
||||
assert result["value"] == 42
|
||||
|
||||
|
||||
class TestConversationServicePaginationWithContainers:
|
||||
def test_pagination_by_last_id_raises_error_when_last_id_missing(self, db_session_with_containers):
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
app, account = factory.create_app_and_account(db_session_with_containers)
|
||||
|
||||
with pytest.raises(LastConversationNotExistsError):
|
||||
ConversationService.pagination_by_last_id(
|
||||
session=db_session_with_containers,
|
||||
app_model=app,
|
||||
user=account,
|
||||
last_id=str(uuid4()),
|
||||
limit=20,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
def test_pagination_by_last_id_with_default_desc_updated_at(self, db_session_with_containers):
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
app, account = factory.create_app_and_account(db_session_with_containers)
|
||||
base_time = datetime(2024, 1, 1, 8, 0, 0)
|
||||
newest = factory.create_conversation(
|
||||
db_session_with_containers,
|
||||
app,
|
||||
account,
|
||||
name="Newest",
|
||||
updated_at=base_time + timedelta(minutes=2),
|
||||
)
|
||||
middle = factory.create_conversation(
|
||||
db_session_with_containers,
|
||||
app,
|
||||
account,
|
||||
name="Middle",
|
||||
updated_at=base_time + timedelta(minutes=1),
|
||||
)
|
||||
oldest = factory.create_conversation(
|
||||
db_session_with_containers,
|
||||
app,
|
||||
account,
|
||||
name="Oldest",
|
||||
updated_at=base_time,
|
||||
)
|
||||
|
||||
result = ConversationService.pagination_by_last_id(
|
||||
session=db_session_with_containers,
|
||||
app_model=app,
|
||||
user=account,
|
||||
last_id=middle.id,
|
||||
limit=10,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
assert newest.id != middle.id
|
||||
assert [conversation.id for conversation in result.data] == [oldest.id]
|
||||
|
||||
def test_pagination_by_last_id_with_name_sort(self, db_session_with_containers):
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
app, account = factory.create_app_and_account(db_session_with_containers)
|
||||
alpha = factory.create_conversation(db_session_with_containers, app, account, name="Alpha")
|
||||
beta = factory.create_conversation(db_session_with_containers, app, account, name="Beta")
|
||||
gamma = factory.create_conversation(db_session_with_containers, app, account, name="Gamma")
|
||||
|
||||
result = ConversationService.pagination_by_last_id(
|
||||
session=db_session_with_containers,
|
||||
app_model=app,
|
||||
user=account,
|
||||
last_id=beta.id,
|
||||
limit=10,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
sort_by="name",
|
||||
)
|
||||
|
||||
assert alpha.id != beta.id
|
||||
assert [conversation.id for conversation in result.data] == [gamma.id]
|
||||
|
||||
def test_pagination_filters_to_end_user_api_source(self, db_session_with_containers):
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
app, account = factory.create_app_and_account(db_session_with_containers)
|
||||
end_user = factory.create_end_user(db_session_with_containers, app)
|
||||
account_conversation = factory.create_conversation(
|
||||
db_session_with_containers,
|
||||
app,
|
||||
account,
|
||||
name="Console Conversation",
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
end_user_conversation = factory.create_conversation(
|
||||
db_session_with_containers,
|
||||
app,
|
||||
end_user,
|
||||
name="API Conversation",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
result = ConversationService.pagination_by_last_id(
|
||||
session=db_session_with_containers,
|
||||
app_model=app,
|
||||
user=end_user,
|
||||
last_id=None,
|
||||
limit=20,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
assert account_conversation.id != end_user_conversation.id
|
||||
assert [conversation.id for conversation in result.data] == [end_user_conversation.id]
|
||||
|
||||
def test_pagination_filters_to_account_console_source(self, db_session_with_containers):
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
app, account = factory.create_app_and_account(db_session_with_containers)
|
||||
end_user = factory.create_end_user(db_session_with_containers, app)
|
||||
account_conversation = factory.create_conversation(
|
||||
db_session_with_containers,
|
||||
app,
|
||||
account,
|
||||
name="Console Conversation",
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
factory.create_conversation(
|
||||
db_session_with_containers,
|
||||
app,
|
||||
end_user,
|
||||
name="API Conversation",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
result = ConversationService.pagination_by_last_id(
|
||||
session=db_session_with_containers,
|
||||
app_model=app,
|
||||
user=account,
|
||||
last_id=None,
|
||||
limit=20,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
assert [conversation.id for conversation in result.data] == [account_conversation.id]
|
||||
@ -3,10 +3,10 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from graphon.variables import StringVariable
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
from graphon.variables import StringVariable
|
||||
from models.workflow import ConversationVariable
|
||||
from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater
|
||||
|
||||
|
||||
@ -0,0 +1,650 @@
|
||||
"""Testcontainers integration tests for SQL-backed DocumentService paths."""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
from unittest.mock import create_autospec, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Account
|
||||
from models.dataset import Dataset, Document
|
||||
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DocumentService
|
||||
from services.errors.account import NoPermissionError
|
||||
|
||||
FIXED_UPLOAD_CREATED_AT = datetime.datetime(2024, 1, 1, 0, 0, 0)
|
||||
|
||||
|
||||
class DocumentServiceIntegrationFactory:
|
||||
@staticmethod
|
||||
def create_dataset(
|
||||
db_session_with_containers,
|
||||
*,
|
||||
tenant_id: str | None = None,
|
||||
created_by: str | None = None,
|
||||
name: str | None = None,
|
||||
) -> Dataset:
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id or str(uuid4()),
|
||||
name=name or f"dataset-{uuid4()}",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=created_by or str(uuid4()),
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_document(
|
||||
db_session_with_containers,
|
||||
*,
|
||||
dataset: Dataset,
|
||||
name: str = "doc.txt",
|
||||
position: int = 1,
|
||||
tenant_id: str | None = None,
|
||||
indexing_status: str = IndexingStatus.COMPLETED,
|
||||
enabled: bool = True,
|
||||
archived: bool = False,
|
||||
is_paused: bool = False,
|
||||
need_summary: bool = False,
|
||||
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
|
||||
batch: str | None = None,
|
||||
data_source_type: str = DataSourceType.UPLOAD_FILE,
|
||||
data_source_info: dict | None = None,
|
||||
created_by: str | None = None,
|
||||
) -> Document:
|
||||
document = Document(
|
||||
tenant_id=tenant_id or dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=position,
|
||||
data_source_type=data_source_type,
|
||||
data_source_info=json.dumps(data_source_info or {}),
|
||||
batch=batch or f"batch-{uuid4()}",
|
||||
name=name,
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=created_by or dataset.created_by,
|
||||
doc_form=doc_form,
|
||||
)
|
||||
document.indexing_status = indexing_status
|
||||
document.enabled = enabled
|
||||
document.archived = archived
|
||||
document.is_paused = is_paused
|
||||
document.need_summary = need_summary
|
||||
if indexing_status == IndexingStatus.COMPLETED:
|
||||
document.completed_at = FIXED_UPLOAD_CREATED_AT
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
return document
|
||||
|
||||
@staticmethod
|
||||
def create_upload_file(
|
||||
db_session_with_containers,
|
||||
*,
|
||||
tenant_id: str,
|
||||
created_by: str,
|
||||
file_id: str | None = None,
|
||||
name: str = "source.txt",
|
||||
) -> UploadFile:
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=f"uploads/{uuid4()}",
|
||||
name=name,
|
||||
size=128,
|
||||
extension="txt",
|
||||
mime_type="text/plain",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=created_by,
|
||||
created_at=FIXED_UPLOAD_CREATED_AT,
|
||||
used=False,
|
||||
)
|
||||
if file_id:
|
||||
upload_file.id = file_id
|
||||
db_session_with_containers.add(upload_file)
|
||||
db_session_with_containers.commit()
|
||||
return upload_file
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def current_user_mock():
|
||||
with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user:
|
||||
current_user.id = str(uuid4())
|
||||
current_user.current_tenant_id = str(uuid4())
|
||||
current_user.current_role = None
|
||||
yield current_user
|
||||
|
||||
|
||||
def test_get_document_returns_none_when_document_id_is_missing(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
|
||||
assert DocumentService.get_document(dataset.id, None) is None
|
||||
|
||||
|
||||
def test_get_document_queries_by_dataset_and_document_id(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
document = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset)
|
||||
|
||||
result = DocumentService.get_document(dataset.id, document.id)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == document.id
|
||||
|
||||
|
||||
def test_get_documents_by_ids_returns_empty_for_empty_input(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
|
||||
result = DocumentService.get_documents_by_ids(dataset.id, [])
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_get_documents_by_ids_uses_single_batch_query(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
doc_a = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, name="a.txt")
|
||||
doc_b = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
name="b.txt",
|
||||
position=2,
|
||||
)
|
||||
|
||||
result = DocumentService.get_documents_by_ids(dataset.id, [doc_a.id, doc_b.id])
|
||||
|
||||
assert {document.id for document in result} == {doc_a.id, doc_b.id}
|
||||
|
||||
|
||||
def test_update_documents_need_summary_returns_zero_for_empty_input(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
|
||||
assert DocumentService.update_documents_need_summary(dataset.id, []) == 0
|
||||
|
||||
|
||||
def test_update_documents_need_summary_updates_matching_non_qa_documents(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
paragraph_doc = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
need_summary=True,
|
||||
)
|
||||
qa_doc = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
position=2,
|
||||
need_summary=True,
|
||||
doc_form=IndexStructureType.QA_INDEX,
|
||||
)
|
||||
|
||||
updated_count = DocumentService.update_documents_need_summary(
|
||||
dataset.id,
|
||||
[paragraph_doc.id, qa_doc.id],
|
||||
need_summary=False,
|
||||
)
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
refreshed_paragraph = db_session_with_containers.get(Document, paragraph_doc.id)
|
||||
refreshed_qa = db_session_with_containers.get(Document, qa_doc.id)
|
||||
assert updated_count == 1
|
||||
assert refreshed_paragraph is not None
|
||||
assert refreshed_qa is not None
|
||||
assert refreshed_paragraph.need_summary is False
|
||||
assert refreshed_qa.need_summary is True
|
||||
|
||||
|
||||
def test_get_document_download_url_uses_signed_url_helper(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
upload_file = DocumentServiceIntegrationFactory.create_upload_file(
|
||||
db_session_with_containers,
|
||||
tenant_id=dataset.tenant_id,
|
||||
created_by=dataset.created_by,
|
||||
)
|
||||
document = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
data_source_info={"upload_file_id": upload_file.id},
|
||||
)
|
||||
|
||||
with patch("services.dataset_service.file_helpers.get_signed_file_url", return_value="signed-url") as get_url:
|
||||
result = DocumentService.get_document_download_url(document)
|
||||
|
||||
assert result == "signed-url"
|
||||
get_url.assert_called_once_with(upload_file_id=upload_file.id, as_attachment=True)
|
||||
|
||||
|
||||
def test_get_upload_file_id_for_upload_file_document_rejects_invalid_source_type(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
document = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
data_source_type=DataSourceType.WEBSITE_CRAWL,
|
||||
data_source_info={"url": "https://example.com"},
|
||||
)
|
||||
|
||||
with pytest.raises(NotFound, match="invalid source"):
|
||||
DocumentService._get_upload_file_id_for_upload_file_document(
|
||||
document,
|
||||
invalid_source_message="invalid source",
|
||||
missing_file_message="missing file",
|
||||
)
|
||||
|
||||
|
||||
def test_get_upload_file_id_for_upload_file_document_rejects_missing_upload_file_id(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
document = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
data_source_info={},
|
||||
)
|
||||
|
||||
with pytest.raises(NotFound, match="missing file"):
|
||||
DocumentService._get_upload_file_id_for_upload_file_document(
|
||||
document,
|
||||
invalid_source_message="invalid source",
|
||||
missing_file_message="missing file",
|
||||
)
|
||||
|
||||
|
||||
def test_get_upload_file_id_for_upload_file_document_returns_string_id(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
document = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
data_source_info={"upload_file_id": 99},
|
||||
)
|
||||
|
||||
result = DocumentService._get_upload_file_id_for_upload_file_document(
|
||||
document,
|
||||
invalid_source_message="invalid source",
|
||||
missing_file_message="missing file",
|
||||
)
|
||||
|
||||
assert result == "99"
|
||||
|
||||
|
||||
def test_get_upload_file_for_upload_file_document_raises_when_file_service_returns_nothing(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
document = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
data_source_info={"upload_file_id": "missing-file"},
|
||||
)
|
||||
|
||||
with patch("services.dataset_service.FileService.get_upload_files_by_ids", return_value={}):
|
||||
with pytest.raises(NotFound, match="Uploaded file not found"):
|
||||
DocumentService._get_upload_file_for_upload_file_document(document)
|
||||
|
||||
|
||||
def test_get_upload_file_for_upload_file_document_returns_upload_file(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
upload_file = DocumentServiceIntegrationFactory.create_upload_file(
|
||||
db_session_with_containers,
|
||||
tenant_id=dataset.tenant_id,
|
||||
created_by=dataset.created_by,
|
||||
)
|
||||
document = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
data_source_info={"upload_file_id": upload_file.id},
|
||||
)
|
||||
|
||||
result = DocumentService._get_upload_file_for_upload_file_document(document)
|
||||
|
||||
assert result.id == upload_file.id
|
||||
|
||||
|
||||
def test_get_upload_files_by_document_id_for_zip_download_raises_for_missing_documents(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
|
||||
with pytest.raises(NotFound, match="Document not found"):
|
||||
DocumentService._get_upload_files_by_document_id_for_zip_download(
|
||||
dataset_id=dataset.id,
|
||||
document_ids=[str(uuid4())],
|
||||
tenant_id=dataset.tenant_id,
|
||||
)
|
||||
|
||||
|
||||
def test_get_upload_files_by_document_id_for_zip_download_rejects_cross_tenant_access(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
upload_file = DocumentServiceIntegrationFactory.create_upload_file(
|
||||
db_session_with_containers,
|
||||
tenant_id=dataset.tenant_id,
|
||||
created_by=dataset.created_by,
|
||||
)
|
||||
document = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
tenant_id=str(uuid4()),
|
||||
data_source_info={"upload_file_id": upload_file.id},
|
||||
)
|
||||
|
||||
with pytest.raises(Forbidden, match="No permission"):
|
||||
DocumentService._get_upload_files_by_document_id_for_zip_download(
|
||||
dataset_id=dataset.id,
|
||||
document_ids=[document.id],
|
||||
tenant_id=dataset.tenant_id,
|
||||
)
|
||||
|
||||
|
||||
def test_get_upload_files_by_document_id_for_zip_download_rejects_missing_upload_files(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
document = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
data_source_info={"upload_file_id": str(uuid4())},
|
||||
)
|
||||
|
||||
with pytest.raises(NotFound, match="Only uploaded-file documents can be downloaded as ZIP"):
|
||||
DocumentService._get_upload_files_by_document_id_for_zip_download(
|
||||
dataset_id=dataset.id,
|
||||
document_ids=[document.id],
|
||||
tenant_id=dataset.tenant_id,
|
||||
)
|
||||
|
||||
|
||||
def test_get_upload_files_by_document_id_for_zip_download_returns_document_keyed_mapping(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
upload_file_a = DocumentServiceIntegrationFactory.create_upload_file(
|
||||
db_session_with_containers,
|
||||
tenant_id=dataset.tenant_id,
|
||||
created_by=dataset.created_by,
|
||||
name="a.txt",
|
||||
)
|
||||
upload_file_b = DocumentServiceIntegrationFactory.create_upload_file(
|
||||
db_session_with_containers,
|
||||
tenant_id=dataset.tenant_id,
|
||||
created_by=dataset.created_by,
|
||||
name="b.txt",
|
||||
)
|
||||
document_a = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
data_source_info={"upload_file_id": upload_file_a.id},
|
||||
)
|
||||
document_b = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
position=2,
|
||||
data_source_info={"upload_file_id": upload_file_b.id},
|
||||
)
|
||||
|
||||
mapping = DocumentService._get_upload_files_by_document_id_for_zip_download(
|
||||
dataset_id=dataset.id,
|
||||
document_ids=[document_a.id, document_b.id],
|
||||
tenant_id=dataset.tenant_id,
|
||||
)
|
||||
|
||||
assert mapping[document_a.id].id == upload_file_a.id
|
||||
assert mapping[document_b.id].id == upload_file_b.id
|
||||
|
||||
|
||||
def test_prepare_document_batch_download_zip_raises_not_found_for_missing_dataset(
|
||||
current_user_mock, flask_app_with_containers
|
||||
):
|
||||
with flask_app_with_containers.app_context():
|
||||
with pytest.raises(NotFound, match="Dataset not found"):
|
||||
DocumentService.prepare_document_batch_download_zip(
|
||||
dataset_id=str(uuid4()),
|
||||
document_ids=[str(uuid4())],
|
||||
tenant_id=current_user_mock.current_tenant_id,
|
||||
current_user=current_user_mock,
|
||||
)
|
||||
|
||||
|
||||
def test_prepare_document_batch_download_zip_translates_permission_error_to_forbidden(
|
||||
db_session_with_containers,
|
||||
current_user_mock,
|
||||
):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=current_user_mock.current_tenant_id,
|
||||
created_by=current_user_mock.id,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"services.dataset_service.DatasetService.check_dataset_permission",
|
||||
side_effect=NoPermissionError("denied"),
|
||||
):
|
||||
with pytest.raises(Forbidden, match="denied"):
|
||||
DocumentService.prepare_document_batch_download_zip(
|
||||
dataset_id=dataset.id,
|
||||
document_ids=[],
|
||||
tenant_id=current_user_mock.current_tenant_id,
|
||||
current_user=current_user_mock,
|
||||
)
|
||||
|
||||
|
||||
def test_prepare_document_batch_download_zip_returns_upload_files_in_requested_order(
|
||||
db_session_with_containers,
|
||||
current_user_mock,
|
||||
):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=current_user_mock.current_tenant_id,
|
||||
created_by=current_user_mock.id,
|
||||
)
|
||||
upload_file_a = DocumentServiceIntegrationFactory.create_upload_file(
|
||||
db_session_with_containers,
|
||||
tenant_id=dataset.tenant_id,
|
||||
created_by=dataset.created_by,
|
||||
name="a.txt",
|
||||
)
|
||||
upload_file_b = DocumentServiceIntegrationFactory.create_upload_file(
|
||||
db_session_with_containers,
|
||||
tenant_id=dataset.tenant_id,
|
||||
created_by=dataset.created_by,
|
||||
name="b.txt",
|
||||
)
|
||||
document_a = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
data_source_info={"upload_file_id": upload_file_a.id},
|
||||
)
|
||||
document_b = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
position=2,
|
||||
data_source_info={"upload_file_id": upload_file_b.id},
|
||||
)
|
||||
|
||||
upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
|
||||
dataset_id=dataset.id,
|
||||
document_ids=[document_b.id, document_a.id],
|
||||
tenant_id=current_user_mock.current_tenant_id,
|
||||
current_user=current_user_mock,
|
||||
)
|
||||
|
||||
assert [upload_file.id for upload_file in upload_files] == [upload_file_b.id, upload_file_a.id]
|
||||
assert download_name.endswith(".zip")
|
||||
|
||||
|
||||
def test_get_document_by_dataset_id_returns_enabled_documents(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
enabled_document = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
enabled=True,
|
||||
)
|
||||
DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
position=2,
|
||||
enabled=False,
|
||||
)
|
||||
|
||||
result = DocumentService.get_document_by_dataset_id(dataset.id)
|
||||
|
||||
assert [document.id for document in result] == [enabled_document.id]
|
||||
|
||||
|
||||
def test_get_working_documents_by_dataset_id_returns_completed_enabled_unarchived_documents(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
available_document = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
)
|
||||
DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
position=2,
|
||||
indexing_status=IndexingStatus.ERROR,
|
||||
)
|
||||
|
||||
result = DocumentService.get_working_documents_by_dataset_id(dataset.id)
|
||||
|
||||
assert [document.id for document in result] == [available_document.id]
|
||||
|
||||
|
||||
def test_get_error_documents_by_dataset_id_returns_error_and_paused_documents(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
error_document = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
indexing_status=IndexingStatus.ERROR,
|
||||
)
|
||||
paused_document = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
position=2,
|
||||
indexing_status=IndexingStatus.PAUSED,
|
||||
)
|
||||
DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
position=3,
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
)
|
||||
|
||||
result = DocumentService.get_error_documents_by_dataset_id(dataset.id)
|
||||
|
||||
assert {document.id for document in result} == {error_document.id, paused_document.id}
|
||||
|
||||
|
||||
def test_get_batch_documents_filters_by_current_user_tenant(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
batch = f"batch-{uuid4()}"
|
||||
matching_document = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
batch=batch,
|
||||
)
|
||||
DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
position=2,
|
||||
tenant_id=str(uuid4()),
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user:
|
||||
current_user.current_tenant_id = dataset.tenant_id
|
||||
result = DocumentService.get_batch_documents(dataset.id, batch)
|
||||
|
||||
assert [document.id for document in result] == [matching_document.id]
|
||||
|
||||
|
||||
def test_get_document_file_detail_returns_upload_file(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
upload_file = DocumentServiceIntegrationFactory.create_upload_file(
|
||||
db_session_with_containers,
|
||||
tenant_id=dataset.tenant_id,
|
||||
created_by=dataset.created_by,
|
||||
)
|
||||
|
||||
result = DocumentService.get_document_file_detail(upload_file.id)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == upload_file.id
|
||||
|
||||
|
||||
def test_delete_document_emits_signal_and_commits(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
upload_file = DocumentServiceIntegrationFactory.create_upload_file(
|
||||
db_session_with_containers,
|
||||
tenant_id=dataset.tenant_id,
|
||||
created_by=dataset.created_by,
|
||||
)
|
||||
document = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
data_source_info={"upload_file_id": upload_file.id},
|
||||
)
|
||||
|
||||
with patch("services.dataset_service.document_was_deleted.send") as signal_send:
|
||||
DocumentService.delete_document(document)
|
||||
|
||||
assert db_session_with_containers.get(Document, document.id) is None
|
||||
signal_send.assert_called_once_with(
|
||||
document.id,
|
||||
dataset_id=document.dataset_id,
|
||||
doc_form=document.doc_form,
|
||||
file_id=upload_file.id,
|
||||
)
|
||||
|
||||
|
||||
def test_delete_documents_ignores_empty_input(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
|
||||
with patch("services.dataset_service.batch_clean_document_task.delay") as delay:
|
||||
DocumentService.delete_documents(dataset, [])
|
||||
|
||||
delay.assert_not_called()
|
||||
|
||||
|
||||
def test_delete_documents_deletes_rows_and_dispatches_cleanup_task(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
dataset.chunk_structure = IndexStructureType.PARAGRAPH_INDEX
|
||||
db_session_with_containers.commit()
|
||||
upload_file_a = DocumentServiceIntegrationFactory.create_upload_file(
|
||||
db_session_with_containers,
|
||||
tenant_id=dataset.tenant_id,
|
||||
created_by=dataset.created_by,
|
||||
name="a.txt",
|
||||
)
|
||||
upload_file_b = DocumentServiceIntegrationFactory.create_upload_file(
|
||||
db_session_with_containers,
|
||||
tenant_id=dataset.tenant_id,
|
||||
created_by=dataset.created_by,
|
||||
name="b.txt",
|
||||
)
|
||||
document_a = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
data_source_info={"upload_file_id": upload_file_a.id},
|
||||
)
|
||||
document_b = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
position=2,
|
||||
data_source_info={"upload_file_id": upload_file_b.id},
|
||||
)
|
||||
|
||||
with patch("services.dataset_service.batch_clean_document_task.delay") as delay:
|
||||
DocumentService.delete_documents(dataset, [document_a.id, document_b.id])
|
||||
|
||||
assert db_session_with_containers.get(Document, document_a.id) is None
|
||||
assert db_session_with_containers.get(Document, document_b.id) is None
|
||||
delay.assert_called_once()
|
||||
args = delay.call_args.args
|
||||
assert args[0] == [document_a.id, document_b.id]
|
||||
assert args[1] == dataset.id
|
||||
assert set(args[3]) == {upload_file_a.id, upload_file_b.id}
|
||||
|
||||
|
||||
def test_get_documents_position_returns_next_position_when_documents_exist(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, position=3)
|
||||
|
||||
assert DocumentService.get_documents_position(dataset.id) == 4
|
||||
|
||||
|
||||
def test_get_documents_position_defaults_to_one_when_dataset_is_empty(db_session_with_containers):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
|
||||
assert DocumentService.get_documents_position(dataset.id) == 1
|
||||
@ -5,7 +5,6 @@ from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from graphon.runtime import VariablePool
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from configs import dify_config
|
||||
@ -16,6 +15,7 @@ from core.workflow.human_input_compat import (
|
||||
ExternalRecipient,
|
||||
MemberRecipient,
|
||||
)
|
||||
from graphon.runtime import VariablePool
|
||||
from models.account import Account, TenantAccountJoin
|
||||
from services import human_input_delivery_test_service as service_module
|
||||
from services.human_input_delivery_test_service import (
|
||||
|
||||
@ -8,11 +8,11 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from graphon.file import FileType
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.file import FileType
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import (
|
||||
ConversationFromSource,
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import inspect
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@ -6,6 +8,8 @@ from pydantic import TypeAdapter, ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType
|
||||
from core.tools.errors import ApiToolProviderNotFoundError
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from models import Account, Tenant
|
||||
from models.tools import ApiToolProvider
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
@ -590,30 +594,204 @@ class TestApiToolManageService:
|
||||
with pytest.raises(ValueError, match="you have not added provider"):
|
||||
ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent")
|
||||
|
||||
def test_update_api_tool_provider_not_found(
|
||||
def test_update_api_tool_provider_success(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test update raises ValueError when original provider not found."""
|
||||
fake = Faker()
|
||||
|
||||
# Firmware fix for cache.delete() in update flow
|
||||
mock_encrypter = mock_external_service_dependencies["encrypter"]
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.delete.return_value = None
|
||||
mock_encrypter.return_value = (mock_encrypter, mock_cache)
|
||||
|
||||
# Get fake account and tenant
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="does not exists"):
|
||||
ApiToolManageService.update_api_tool_provider(
|
||||
# original provider name
|
||||
original_name = "original-provider"
|
||||
|
||||
# Create original provider
|
||||
_ = ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=original_name,
|
||||
icon={"type": "emoji", "value": "🔧"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=self._create_test_openapi_schema(),
|
||||
privacy_policy="",
|
||||
custom_disclaimer="",
|
||||
labels=["old-label"],
|
||||
)
|
||||
|
||||
# new provide name and new labels for update
|
||||
new_name = "updated-provider"
|
||||
new_labels = ["new-label-1", "new-label-2"]
|
||||
|
||||
# Reset mock history so assertions focus on update path only
|
||||
mock_external_service_dependencies["encrypter"].reset_mock()
|
||||
mock_external_service_dependencies["provider_controller"].from_db.reset_mock()
|
||||
mock_external_service_dependencies["tool_label_manager"].update_tool_labels.reset_mock()
|
||||
|
||||
# Act: Update the provider with new values
|
||||
result = ApiToolManageService.update_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
# new provider name - changed 1
|
||||
provider_name=new_name,
|
||||
original_provider=original_name,
|
||||
# new icon - changed 2
|
||||
icon={"type": "emoji", "value": "🚀"},
|
||||
credentials={"auth_type": "none"},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=self._create_test_openapi_schema(),
|
||||
# new privacy policy - changed 3
|
||||
privacy_policy="https://new-policy.com",
|
||||
# new custom disclaimer - changed 4
|
||||
custom_disclaimer="New disclaimer",
|
||||
# new labels - changed 5 (However, we will not verify this, not this layer responsibility.)
|
||||
labels=new_labels,
|
||||
)
|
||||
|
||||
# Assert: Verify the result
|
||||
assert result == {"result": "success"}
|
||||
|
||||
# Get the updated provider from the database
|
||||
updated_provider: ApiToolProvider | None = (
|
||||
db_session_with_containers.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == new_name)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Verify the provider was updated successfully
|
||||
assert updated_provider is not None
|
||||
|
||||
# Manually refresh to keep object detachment
|
||||
db_session_with_containers.refresh(updated_provider)
|
||||
# Verify all the updated fields
|
||||
# - changed 1
|
||||
assert updated_provider.name == new_name
|
||||
# - changed 2
|
||||
icon_data = json.loads(updated_provider.icon)
|
||||
assert icon_data["type"] == "emoji"
|
||||
assert icon_data["value"] == "🚀"
|
||||
# - changed 3
|
||||
assert updated_provider.privacy_policy == "https://new-policy.com"
|
||||
# - changed 4
|
||||
assert updated_provider.custom_disclaimer == "New disclaimer"
|
||||
|
||||
# Verify old provider name no longer exists after rename
|
||||
original_provider: ApiToolProvider | None = (
|
||||
db_session_with_containers.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == original_name)
|
||||
.first()
|
||||
)
|
||||
assert original_provider is None
|
||||
|
||||
# Verify update flow calls critical collaborators
|
||||
mock_external_service_dependencies["provider_controller"].from_db.assert_called_once()
|
||||
mock_external_service_dependencies["encrypter"].assert_called_once()
|
||||
mock_cache.delete.assert_called_once()
|
||||
|
||||
# Deeply verify on session propagation of labels update logics:
|
||||
# Since in refactoring, we pass session down to label manager to keep atomicity.
|
||||
# The assertion here is to verify this.
|
||||
sig = inspect.signature(ToolLabelManager.update_tool_labels)
|
||||
args, kwargs = mock_external_service_dependencies["tool_label_manager"].update_tool_labels.call_args
|
||||
bound_args = sig.bind(*args, **kwargs)
|
||||
passed_session = bound_args.arguments.get("session")
|
||||
# Ensure the type: Session
|
||||
assert isinstance(passed_session, Session), f"Expected Session object, got {type(passed_session)}"
|
||||
assert passed_session is not None, (
|
||||
"Atomicity Failure: Session cannot be passed to Label Manager in update_api_tool_provider"
|
||||
)
|
||||
|
||||
def test_update_api_tool_provider_not_found(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test update raises ValueError when original provider not found.
|
||||
|
||||
This test verifies:
|
||||
- Proper error when trying to update a non-existing original provider
|
||||
- No accidental upsert/new provider creation
|
||||
- No external dependency invocation on early failure path
|
||||
"""
|
||||
# Arrange: Create test account and tenant
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Keep an existing provider in DB to ensure unrelated data remains unchanged
|
||||
existing_provider_name = "existing-provider"
|
||||
_ = ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=existing_provider_name,
|
||||
icon={"type": "emoji", "value": "🔧"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=self._create_test_openapi_schema(),
|
||||
privacy_policy="https://existing-policy.com",
|
||||
custom_disclaimer="Existing disclaimer",
|
||||
labels=["existing-label"],
|
||||
)
|
||||
|
||||
# Reset mock history so assertions focus on update failure path only
|
||||
mock_external_service_dependencies["tool_label_manager"].update_tool_labels.reset_mock()
|
||||
mock_external_service_dependencies["encrypter"].reset_mock()
|
||||
mock_external_service_dependencies["provider_controller"].from_db.reset_mock()
|
||||
|
||||
# Act & Assert: Verify update fails with clear error message
|
||||
target_new_name = "new-provider-name"
|
||||
missing_original_name = "missing-original-provider"
|
||||
with pytest.raises(ApiToolProviderNotFoundError) as exc_info:
|
||||
_ = ApiToolManageService.update_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name="new-name",
|
||||
original_provider="nonexistent",
|
||||
icon={},
|
||||
provider_name=target_new_name,
|
||||
original_provider=missing_original_name,
|
||||
icon={"type": "emoji", "value": "🚀"},
|
||||
credentials={"auth_type": "none"},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=self._create_test_openapi_schema(),
|
||||
privacy_policy=None,
|
||||
custom_disclaimer="",
|
||||
labels=[],
|
||||
privacy_policy="https://new-policy.com",
|
||||
custom_disclaimer="New disclaimer",
|
||||
labels=["new-label"],
|
||||
)
|
||||
|
||||
error = exc_info.value
|
||||
assert error.provider_name == missing_original_name
|
||||
assert error.tenant_id == tenant.id
|
||||
assert error.error_code == "api_tool_provider_not_found"
|
||||
|
||||
# Assert: Existing provider should remain unchanged
|
||||
existing_provider: ApiToolProvider | None = (
|
||||
db_session_with_containers.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == existing_provider_name)
|
||||
.first()
|
||||
)
|
||||
assert existing_provider is not None
|
||||
assert existing_provider.name == existing_provider_name
|
||||
|
||||
# Assert: No new provider should be created
|
||||
unexpected_new_provider: ApiToolProvider | None = (
|
||||
db_session_with_containers.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == target_new_name)
|
||||
.first()
|
||||
)
|
||||
assert unexpected_new_provider is None
|
||||
|
||||
# Assert: Early failure should skip all downstream external interactions
|
||||
mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_not_called()
|
||||
mock_external_service_dependencies["encrypter"].assert_not_called()
|
||||
mock_external_service_dependencies["provider_controller"].from_db.assert_not_called()
|
||||
|
||||
def test_update_api_tool_provider_missing_auth_type(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
|
||||
@ -5,9 +5,6 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
@ -21,6 +18,9 @@ from core.app.app_config.entities import (
|
||||
PromptTemplateEntity,
|
||||
)
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from models import Account, Tenant
|
||||
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
|
||||
@ -4,7 +4,6 @@ import io
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
@ -21,6 +20,7 @@ from controllers.console.app.error import (
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.audio import (
|
||||
|
||||
@ -5,10 +5,10 @@ from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from graphon.variables.types import SegmentType
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.console.app import conversation_variables as conversation_variables_module
|
||||
from graphon.variables.types import SegmentType
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
|
||||
@ -1,6 +1,25 @@
|
||||
import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
from controllers.console.app.mcp_server import AppMCPServerResponse
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.mcp_server import AppMCPServerController, AppMCPServerResponse
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class _ValidatedResponse:
|
||||
def __init__(self, payload):
|
||||
self._payload = payload
|
||||
|
||||
def model_dump(self, mode="json"):
|
||||
return self._payload
|
||||
|
||||
|
||||
class TestAppMCPServerResponse:
|
||||
@ -40,6 +59,18 @@ class TestAppMCPServerResponse:
|
||||
resp = AppMCPServerResponse.model_validate(data)
|
||||
assert resp.parameters == {"already": "parsed"}
|
||||
|
||||
def test_parameters_json_array_parsed(self):
|
||||
data = {
|
||||
"id": "s1",
|
||||
"name": "test",
|
||||
"server_code": "code",
|
||||
"description": "desc",
|
||||
"status": "active",
|
||||
"parameters": '["a", "b"]',
|
||||
}
|
||||
resp = AppMCPServerResponse.model_validate(data)
|
||||
assert resp.parameters == ["a", "b"]
|
||||
|
||||
def test_timestamps_normalized(self):
|
||||
dt = datetime.datetime(2024, 1, 1, 0, 0, 0, tzinfo=datetime.UTC)
|
||||
data = {
|
||||
@ -68,3 +99,40 @@ class TestAppMCPServerResponse:
|
||||
resp = AppMCPServerResponse.model_validate(data)
|
||||
assert resp.created_at is None
|
||||
assert resp.updated_at is None
|
||||
|
||||
|
||||
class TestAppMCPServerController:
|
||||
def test_get_returns_empty_dict_when_server_missing(self):
|
||||
api = AppMCPServerController()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with patch("controllers.console.app.mcp_server.db.session.scalar", return_value=None):
|
||||
response = method(api, app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response == {}
|
||||
|
||||
def test_post_returns_201(self):
|
||||
api = AppMCPServerController()
|
||||
method = unwrap(api.post)
|
||||
payload = {"parameters": {"timeout": 30}}
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch("controllers.console.app.mcp_server.current_account_with_tenant", return_value=(None, "tenant-1")),
|
||||
patch("controllers.console.app.mcp_server.db.session.add"),
|
||||
patch("controllers.console.app.mcp_server.db.session.commit"),
|
||||
patch("controllers.console.app.mcp_server.AppMCPServer.generate_server_code", return_value="server-code"),
|
||||
patch(
|
||||
"controllers.console.app.mcp_server.AppMCPServerResponse.model_validate",
|
||||
return_value=_ValidatedResponse({"id": "server-1"}),
|
||||
),
|
||||
):
|
||||
response, status_code = method(
|
||||
api, app_model=SimpleNamespace(id="app-1", name="Demo App", description="App description")
|
||||
)
|
||||
|
||||
assert response == {"id": "server-1"}
|
||||
assert status_code == 201
|
||||
|
||||
@ -6,11 +6,11 @@ from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from graphon.file import File, FileTransferMethod, FileType
|
||||
from werkzeug.exceptions import HTTPException, NotFound
|
||||
|
||||
from controllers.console.app import workflow as workflow_module
|
||||
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from graphon.file import File, FileTransferMethod, FileType
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
|
||||
@ -2,9 +2,8 @@ from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
|
||||
from controllers.console.app import workflow_app_log as workflow_app_log_module
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
|
||||
|
||||
def test_workflow_app_log_query_parses_bool_and_datetime():
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
@ -18,6 +17,7 @@ from controllers.console.datasets.rag_pipeline.datasource_auth import (
|
||||
DatasourceUpdateProviderNameApi,
|
||||
)
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user