mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 12:59:18 +08:00
Merge branch 'main' into feat/marketplace-template
This commit is contained in:
commit
197d8cec10
@ -9,7 +9,7 @@ from graphon.enums import NodeType
|
||||
from graphon.file import File
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
|
||||
|
||||
@ -268,22 +268,18 @@ class DraftWorkflowApi(Resource):
|
||||
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
|
||||
payload_data: dict[str, Any] | None = None
|
||||
if "application/json" in content_type:
|
||||
payload_data = request.get_json(silent=True)
|
||||
if not isinstance(payload_data, dict):
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
|
||||
elif "text/plain" in content_type:
|
||||
try:
|
||||
payload_data = json.loads(request.data.decode("utf-8"))
|
||||
except json.JSONDecodeError:
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
if not isinstance(payload_data, dict):
|
||||
args_model = SyncDraftWorkflowPayload.model_validate_json(request.data)
|
||||
except (ValueError, ValidationError):
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
else:
|
||||
abort(415)
|
||||
|
||||
args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
|
||||
args = args_model.model_dump()
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ from typing import Any, Literal, cast
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, marshal_with # type: ignore
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
|
||||
|
||||
@ -186,29 +186,14 @@ class DraftRagPipelineApi(Resource):
|
||||
|
||||
if "application/json" in content_type:
|
||||
payload_dict = console_ns.payload or {}
|
||||
payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
|
||||
elif "text/plain" in content_type:
|
||||
try:
|
||||
data = json.loads(request.data.decode("utf-8"))
|
||||
if "graph" not in data or "features" not in data:
|
||||
raise ValueError("graph or features not found in data")
|
||||
|
||||
if not isinstance(data.get("graph"), dict):
|
||||
raise ValueError("graph is not a dict")
|
||||
|
||||
payload_dict = {
|
||||
"graph": data.get("graph"),
|
||||
"features": data.get("features"),
|
||||
"hash": data.get("hash"),
|
||||
"environment_variables": data.get("environment_variables"),
|
||||
"conversation_variables": data.get("conversation_variables"),
|
||||
"rag_pipeline_variables": data.get("rag_pipeline_variables"),
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
payload = DraftWorkflowSyncPayload.model_validate_json(request.data)
|
||||
except (ValueError, ValidationError):
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
else:
|
||||
abort(415)
|
||||
|
||||
payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
|
||||
try:
|
||||
|
||||
@ -38,6 +38,7 @@ from core.ops.entities.trace_entity import (
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.utils import JSON_DICT_ADAPTER
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser, MessageFile
|
||||
@ -469,7 +470,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
llm_attributes[SpanAttributes.LLM_PROVIDER] = trace_info.message_data.model_provider
|
||||
|
||||
if trace_info.message_data and trace_info.message_data.message_metadata:
|
||||
metadata_dict = json.loads(trace_info.message_data.message_metadata)
|
||||
metadata_dict = JSON_DICT_ADAPTER.validate_json(trace_info.message_data.message_metadata)
|
||||
if model_params := metadata_dict.get("model_parameters"):
|
||||
llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params)
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
@ -25,6 +24,7 @@ from core.ops.entities.trace_entity import (
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.utils import JSON_DICT_ADAPTER
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
@ -153,7 +153,7 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||
inputs = node.process_data # contains request URL
|
||||
|
||||
if not inputs:
|
||||
inputs = json.loads(node.inputs) if node.inputs else {}
|
||||
inputs = JSON_DICT_ADAPTER.validate_json(node.inputs) if node.inputs else {}
|
||||
|
||||
node_span = start_span_no_context(
|
||||
name=node.title,
|
||||
@ -180,7 +180,7 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||
|
||||
# End node span
|
||||
finished_at = node.created_at + timedelta(seconds=node.elapsed_time)
|
||||
outputs = json.loads(node.outputs) if node.outputs else {}
|
||||
outputs = JSON_DICT_ADAPTER.validate_json(node.outputs) if node.outputs else {}
|
||||
if node.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
|
||||
outputs = self._parse_knowledge_retrieval_outputs(outputs)
|
||||
elif node.node_type == BuiltinNodeTypes.LLM:
|
||||
@ -216,8 +216,8 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||
return {}, {}
|
||||
|
||||
try:
|
||||
data = json.loads(node.process_data)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
data = JSON_DICT_ADAPTER.validate_json(node.process_data)
|
||||
except (ValueError, TypeError):
|
||||
return {}, {}
|
||||
|
||||
inputs = self._parse_prompts(data.get("prompts"))
|
||||
|
||||
@ -11,8 +11,10 @@ from uuid import UUID, uuid4
|
||||
|
||||
from cachetools import LRUCache
|
||||
from flask import current_app
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
|
||||
from core.ops.entities.config_entity import (
|
||||
@ -33,7 +35,7 @@ from core.ops.entities.trace_entity import (
|
||||
WorkflowNodeTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.utils import get_message_data
|
||||
from core.ops.utils import JSON_DICT_ADAPTER, get_message_data
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.account import Tenant
|
||||
@ -50,6 +52,14 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _AppTracingConfig(TypedDict, total=False):
|
||||
enabled: bool
|
||||
tracing_provider: str | None
|
||||
|
||||
|
||||
_app_tracing_config_adapter: TypeAdapter[_AppTracingConfig] = TypeAdapter(_AppTracingConfig)
|
||||
|
||||
|
||||
def _lookup_app_and_workspace_names(app_id: str | None, tenant_id: str | None) -> tuple[str, str]:
|
||||
"""Return (app_name, workspace_name) for the given IDs. Falls back to empty strings."""
|
||||
app_name = ""
|
||||
@ -468,7 +478,7 @@ class OpsTraceManager:
|
||||
if app is None:
|
||||
return None
|
||||
|
||||
app_ops_trace_config = json.loads(app.tracing) if app.tracing else None
|
||||
app_ops_trace_config = _app_tracing_config_adapter.validate_json(app.tracing) if app.tracing else None
|
||||
if app_ops_trace_config is None:
|
||||
return None
|
||||
if not app_ops_trace_config.get("enabled"):
|
||||
@ -560,7 +570,7 @@ class OpsTraceManager:
|
||||
raise ValueError("App not found")
|
||||
if not app.tracing:
|
||||
return {"enabled": False, "tracing_provider": None}
|
||||
app_trace_config = json.loads(app.tracing)
|
||||
app_trace_config = _app_tracing_config_adapter.validate_json(app.tracing)
|
||||
return app_trace_config
|
||||
|
||||
@staticmethod
|
||||
@ -636,7 +646,6 @@ class TraceTask:
|
||||
carries ``total_tokens``. Projects only the ``outputs`` column to avoid loading
|
||||
large JSON blobs unnecessarily.
|
||||
"""
|
||||
import json
|
||||
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
@ -658,7 +667,7 @@ class TraceTask:
|
||||
if not raw:
|
||||
continue
|
||||
try:
|
||||
outputs = json.loads(raw) if isinstance(raw, str) else raw
|
||||
outputs = JSON_DICT_ADAPTER.validate_json(raw) if isinstance(raw, str) else raw
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
if not isinstance(outputs, dict):
|
||||
@ -1420,7 +1429,7 @@ class TraceTask:
|
||||
return {}
|
||||
|
||||
try:
|
||||
metadata = json.loads(message_data.message_metadata)
|
||||
metadata = JSON_DICT_ADAPTER.validate_json(message_data.message_metadata)
|
||||
usage = metadata.get("usage", {})
|
||||
time_to_first_token = usage.get("time_to_first_token")
|
||||
time_to_generate = usage.get("time_to_generate")
|
||||
@ -1430,7 +1439,7 @@ class TraceTask:
|
||||
"llm_streaming_time_to_generate": time_to_generate,
|
||||
"is_streaming_request": time_to_first_token is not None,
|
||||
}
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
except (ValueError, AttributeError):
|
||||
return {}
|
||||
|
||||
|
||||
|
||||
@ -3,11 +3,14 @@ from datetime import datetime
|
||||
from typing import Any, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
|
||||
from models.engine import db
|
||||
from models.model import Message
|
||||
|
||||
JSON_DICT_ADAPTER: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
|
||||
|
||||
|
||||
def filter_none_values(data: dict[str, Any]) -> dict[str, Any]:
|
||||
new_data = {}
|
||||
|
||||
@ -10,6 +10,7 @@ from mysql.connector import Error as MySQLError
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import parse_metadata_json
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
@ -178,9 +179,7 @@ class AlibabaCloudMySQLVector(BaseVector):
|
||||
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids)
|
||||
docs = []
|
||||
for record in cur:
|
||||
metadata = record["meta"]
|
||||
if isinstance(metadata, str):
|
||||
metadata = json.loads(metadata)
|
||||
metadata = parse_metadata_json(record["meta"])
|
||||
docs.append(Document(page_content=record["text"], metadata=metadata))
|
||||
return docs
|
||||
|
||||
@ -263,15 +262,13 @@ class AlibabaCloudMySQLVector(BaseVector):
|
||||
# similarity = 1 / (1 + distance)
|
||||
similarity = 1.0 / (1.0 + distance)
|
||||
|
||||
metadata = record["meta"]
|
||||
if isinstance(metadata, str):
|
||||
metadata = json.loads(metadata)
|
||||
metadata = parse_metadata_json(record["meta"])
|
||||
metadata["score"] = similarity
|
||||
metadata["distance"] = distance
|
||||
|
||||
if similarity >= score_threshold:
|
||||
docs.append(Document(page_content=record["text"], metadata=metadata))
|
||||
except (ValueError, json.JSONDecodeError) as e:
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning("Error processing search result: %s", e)
|
||||
continue
|
||||
|
||||
@ -306,9 +303,7 @@ class AlibabaCloudMySQLVector(BaseVector):
|
||||
)
|
||||
docs = []
|
||||
for record in cur:
|
||||
metadata = record["meta"]
|
||||
if isinstance(metadata, str):
|
||||
metadata = json.loads(metadata)
|
||||
metadata = parse_metadata_json(record["meta"])
|
||||
metadata["score"] = float(record["score"])
|
||||
docs.append(Document(page_content=record["text"], metadata=metadata))
|
||||
return docs
|
||||
|
||||
@ -8,6 +8,7 @@ _import_err_msg = (
|
||||
"please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
|
||||
)
|
||||
|
||||
from core.rag.datasource.vdb.field import parse_metadata_json
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
@ -257,7 +258,7 @@ class AnalyticdbVectorOpenAPI:
|
||||
documents = []
|
||||
for match in response.body.matches.match:
|
||||
if match.score >= score_threshold:
|
||||
metadata = json.loads(match.metadata.get("metadata_"))
|
||||
metadata = parse_metadata_json(match.metadata.get("metadata_"))
|
||||
metadata["score"] = match.score
|
||||
doc = Document(
|
||||
page_content=match.metadata.get("page_content"),
|
||||
@ -294,7 +295,7 @@ class AnalyticdbVectorOpenAPI:
|
||||
documents = []
|
||||
for match in response.body.matches.match:
|
||||
if match.score >= score_threshold:
|
||||
metadata = json.loads(match.metadata.get("metadata_"))
|
||||
metadata = parse_metadata_json(match.metadata.get("metadata_"))
|
||||
metadata["score"] = match.score
|
||||
doc = Document(
|
||||
page_content=match.metadata.get("page_content"),
|
||||
|
||||
@ -29,6 +29,7 @@ from pymochow.model.table import AnnSearch, BM25SearchRequest, HNSWSearchParams,
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field as VDBField
|
||||
from core.rag.datasource.vdb.field import parse_metadata_json
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
@ -173,15 +174,9 @@ class BaiduVector(BaseVector):
|
||||
score = row.get("score", 0.0)
|
||||
meta = row_data.get(VDBField.METADATA_KEY, {})
|
||||
|
||||
# Handle both JSON string and dict formats for backward compatibility
|
||||
if isinstance(meta, str):
|
||||
try:
|
||||
import json
|
||||
|
||||
meta = json.loads(meta)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
meta = {}
|
||||
elif not isinstance(meta, dict):
|
||||
try:
|
||||
meta = parse_metadata_json(meta)
|
||||
except (ValueError, TypeError):
|
||||
meta = {}
|
||||
|
||||
if score >= score_threshold:
|
||||
@ -200,7 +195,11 @@ class BaiduVector(BaseVector):
|
||||
raise
|
||||
|
||||
def _init_client(self, config) -> MochowClient:
|
||||
config = Configuration(credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint)
|
||||
config = Configuration(
|
||||
credentials=BceCredentials(config.account, config.api_key),
|
||||
endpoint=config.endpoint,
|
||||
connection_timeout_in_mills=config.connection_timeout_in_mills,
|
||||
)
|
||||
client = MochowClient(config)
|
||||
return client
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@ if TYPE_CHECKING:
|
||||
from clickzetta.connector.v0.connection import Connection # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.field import Field, parse_metadata_json
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
@ -357,18 +357,19 @@ class ClickzettaVector(BaseVector):
|
||||
"""
|
||||
try:
|
||||
if raw_metadata:
|
||||
metadata = json.loads(raw_metadata)
|
||||
# First parse may yield a string (double-encoded JSON) so use json.loads
|
||||
first_pass = json.loads(raw_metadata)
|
||||
|
||||
# Handle double-encoded JSON
|
||||
if isinstance(metadata, str):
|
||||
metadata = json.loads(metadata)
|
||||
|
||||
# Ensure we have a dict
|
||||
if not isinstance(metadata, dict):
|
||||
if isinstance(first_pass, str):
|
||||
metadata = parse_metadata_json(first_pass)
|
||||
elif isinstance(first_pass, dict):
|
||||
metadata = first_pass
|
||||
else:
|
||||
metadata = {}
|
||||
else:
|
||||
metadata = {}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
except (json.JSONDecodeError, ValueError, TypeError):
|
||||
logger.exception("JSON parsing failed for metadata")
|
||||
# Fallback: extract document_id with regex
|
||||
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', raw_metadata or "")
|
||||
@ -930,17 +931,18 @@ class ClickzettaVector(BaseVector):
|
||||
# Parse metadata from JSON string (may be double-encoded)
|
||||
try:
|
||||
if row[2]:
|
||||
metadata = json.loads(row[2])
|
||||
# First parse may yield a string (double-encoded JSON)
|
||||
first_pass = json.loads(row[2])
|
||||
|
||||
# If result is a string, it's double-encoded JSON - parse again
|
||||
if isinstance(metadata, str):
|
||||
metadata = json.loads(metadata)
|
||||
|
||||
if not isinstance(metadata, dict):
|
||||
if isinstance(first_pass, str):
|
||||
metadata = parse_metadata_json(first_pass)
|
||||
elif isinstance(first_pass, dict):
|
||||
metadata = first_pass
|
||||
else:
|
||||
metadata = {}
|
||||
else:
|
||||
metadata = {}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
except (json.JSONDecodeError, ValueError, TypeError):
|
||||
logger.exception("JSON parsing failed")
|
||||
# Fallback: extract document_id with regex
|
||||
|
||||
|
||||
@ -1,4 +1,24 @@
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
_metadata_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
|
||||
|
||||
|
||||
def parse_metadata_json(raw: Any) -> dict[str, Any]:
|
||||
"""Parse metadata from a JSON string or pass through an existing dict.
|
||||
|
||||
Many VDB drivers return metadata as either a JSON string or an already-
|
||||
decoded dict depending on the column type and driver version.
|
||||
"""
|
||||
if raw is None or raw in ("", b""):
|
||||
return {}
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
if not isinstance(raw, (str, bytes, bytearray)):
|
||||
return {}
|
||||
return _metadata_adapter.validate_json(raw)
|
||||
|
||||
|
||||
class Field(StrEnum):
|
||||
|
||||
@ -9,6 +9,7 @@ from psycopg import sql as psql
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import parse_metadata_json
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
@ -217,8 +218,7 @@ class HologresVector(BaseVector):
|
||||
text = row[2]
|
||||
meta = row[3]
|
||||
|
||||
if isinstance(meta, str):
|
||||
meta = json.loads(meta)
|
||||
meta = parse_metadata_json(meta)
|
||||
|
||||
# Convert distance to similarity score (consistent with pgvector)
|
||||
score = 1 - distance
|
||||
@ -265,8 +265,7 @@ class HologresVector(BaseVector):
|
||||
meta = row[2]
|
||||
score = row[-1] # score is the last column from return_score
|
||||
|
||||
if isinstance(meta, str):
|
||||
meta = json.loads(meta)
|
||||
meta = parse_metadata_json(meta)
|
||||
|
||||
meta["score"] = score
|
||||
docs.append(Document(page_content=text, metadata=meta))
|
||||
|
||||
@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from configs import dify_config
|
||||
from configs.middleware.vdb.iris_config import IrisVectorConfig
|
||||
from core.rag.datasource.vdb.field import parse_metadata_json
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
@ -269,7 +270,7 @@ class IrisVector(BaseVector):
|
||||
if len(row) >= 4:
|
||||
text, meta_str, score = row[1], row[2], float(row[3])
|
||||
if score >= score_threshold:
|
||||
metadata = json.loads(meta_str) if meta_str else {}
|
||||
metadata = parse_metadata_json(meta_str)
|
||||
metadata["score"] = score
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
return docs
|
||||
@ -384,7 +385,7 @@ class IrisVector(BaseVector):
|
||||
meta_str = row[2]
|
||||
score_value = row[3]
|
||||
|
||||
metadata = json.loads(meta_str) if meta_str else {}
|
||||
metadata = parse_metadata_json(meta_str)
|
||||
# Add score to metadata for hybrid search compatibility
|
||||
score = float(score_value) if score_value is not None else 0.0
|
||||
metadata["score"] = score
|
||||
|
||||
@ -9,6 +9,7 @@ from mo_vector.client import MoVectorClient # type: ignore
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import parse_metadata_json
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
@ -196,11 +197,7 @@ class MatrixoneVector(BaseVector):
|
||||
|
||||
docs = []
|
||||
for result in results:
|
||||
metadata = result.metadata
|
||||
if isinstance(metadata, str):
|
||||
import json
|
||||
|
||||
metadata = json.loads(metadata)
|
||||
metadata = parse_metadata_json(result.metadata)
|
||||
score = 1 - result.distance
|
||||
if score >= score_threshold:
|
||||
metadata["score"] = score
|
||||
|
||||
@ -10,6 +10,7 @@ from sqlalchemy.dialects.mysql import LONGTEXT
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import parse_metadata_json
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
@ -366,8 +367,8 @@ class OceanBaseVector(BaseVector):
|
||||
|
||||
# Parse metadata JSON
|
||||
try:
|
||||
metadata = json.loads(metadata_str) if isinstance(metadata_str, str) else metadata_str
|
||||
except json.JSONDecodeError:
|
||||
metadata = parse_metadata_json(metadata_str)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning("Invalid JSON metadata: %s", metadata_str)
|
||||
metadata = {}
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ from pydantic import BaseModel, model_validator
|
||||
from tablestore import BatchGetRowRequest, TableInBatchGetRowItem
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.field import Field, parse_metadata_json
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
@ -73,7 +73,8 @@ class TableStoreVector(BaseVector):
|
||||
for item in table_result:
|
||||
if item.is_ok and item.row:
|
||||
kv = {k: v for k, v, _ in item.row.attribute_columns}
|
||||
docs.append(Document(page_content=kv[Field.CONTENT_KEY], metadata=json.loads(kv[Field.METADATA_KEY])))
|
||||
metadata = parse_metadata_json(kv[Field.METADATA_KEY])
|
||||
docs.append(Document(page_content=kv[Field.CONTENT_KEY], metadata=metadata))
|
||||
return docs
|
||||
|
||||
def get_type(self) -> str:
|
||||
@ -311,7 +312,7 @@ class TableStoreVector(BaseVector):
|
||||
metadata_str = ots_column_map.get(Field.METADATA_KEY)
|
||||
|
||||
vector = json.loads(vector_str) if vector_str else None
|
||||
metadata = json.loads(metadata_str) if metadata_str else {}
|
||||
metadata = parse_metadata_json(metadata_str)
|
||||
|
||||
metadata["score"] = search_hit.score
|
||||
|
||||
@ -371,7 +372,7 @@ class TableStoreVector(BaseVector):
|
||||
ots_column_map[col[0]] = col[1]
|
||||
|
||||
metadata_str = ots_column_map.get(Field.METADATA_KEY)
|
||||
metadata = json.loads(metadata_str) if metadata_str else {}
|
||||
metadata = parse_metadata_json(metadata_str)
|
||||
|
||||
vector_str = ots_column_map.get(Field.VECTOR)
|
||||
vector = json.loads(vector_str) if vector_str else None
|
||||
|
||||
@ -11,6 +11,7 @@ from tcvectordb.model import index as vdb_index # type: ignore
|
||||
from tcvectordb.model.document import AnnSearch, Filter, KeywordSearch, WeightedRerank # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import parse_metadata_json
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
@ -286,13 +287,10 @@ class TencentVector(BaseVector):
|
||||
return docs
|
||||
|
||||
for result in res[0]:
|
||||
meta = result.get(self.field_metadata)
|
||||
if isinstance(meta, str):
|
||||
# Compatible with version 1.1.3 and below.
|
||||
meta = json.loads(meta)
|
||||
score = 1 - result.get("score", 0.0)
|
||||
else:
|
||||
score = result.get("score", 0.0)
|
||||
raw_meta = result.get(self.field_metadata)
|
||||
# Compatible with version 1.1.3 and below: str means old driver.
|
||||
score = (1 - result.get("score", 0.0)) if isinstance(raw_meta, str) else result.get("score", 0.0)
|
||||
meta = parse_metadata_json(raw_meta)
|
||||
if score >= score_threshold:
|
||||
meta["score"] = score
|
||||
doc = Document(page_content=result.get(self.field_text), metadata=meta)
|
||||
|
||||
@ -9,7 +9,7 @@ from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.orm import Session, declarative_base
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.field import Field, parse_metadata_json
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
@ -228,7 +228,7 @@ class TiDBVector(BaseVector):
|
||||
)
|
||||
results = [(row[0], row[1], row[2]) for row in res]
|
||||
for meta, text, distance in results:
|
||||
metadata = json.loads(meta)
|
||||
metadata = parse_metadata_json(meta)
|
||||
metadata["score"] = 1 - distance
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
return docs
|
||||
|
||||
@ -15,6 +15,7 @@ from volcengine.viking_db import ( # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field as vdb_Field
|
||||
from core.rag.datasource.vdb.field import parse_metadata_json
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
@ -163,7 +164,7 @@ class VikingDBVector(BaseVector):
|
||||
for result in results:
|
||||
metadata = result.fields.get(vdb_Field.METADATA_KEY)
|
||||
if metadata is not None:
|
||||
metadata = json.loads(metadata)
|
||||
metadata = parse_metadata_json(metadata)
|
||||
if metadata.get(key) == value:
|
||||
ids.append(result.id)
|
||||
return ids
|
||||
@ -189,9 +190,7 @@ class VikingDBVector(BaseVector):
|
||||
|
||||
docs = []
|
||||
for result in results:
|
||||
metadata = result.fields.get(vdb_Field.METADATA_KEY)
|
||||
if metadata is not None:
|
||||
metadata = json.loads(metadata)
|
||||
metadata = parse_metadata_json(result.fields.get(vdb_Field.METADATA_KEY))
|
||||
if result.score >= score_threshold:
|
||||
metadata["score"] = result.score
|
||||
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY), metadata=metadata)
|
||||
|
||||
@ -35,7 +35,10 @@ class IndexProcessor:
|
||||
if "parent_mode" in preview:
|
||||
data.parent_mode = preview["parent_mode"]
|
||||
|
||||
for item in preview["preview"]:
|
||||
# Different index processors return different preview shapes:
|
||||
# - paragraph/parent-child processors: {"preview": [...]}
|
||||
# - QA processor: {"qa_preview": [...]} (no "preview" key)
|
||||
for item in preview.get("preview", []):
|
||||
if "content" in item and "child_chunks" in item:
|
||||
data.preview.append(
|
||||
PreviewItem(content=item["content"], child_chunks=item["child_chunks"], summary=None)
|
||||
@ -44,6 +47,10 @@ class IndexProcessor:
|
||||
data.qa_preview.append(QaPreview(question=item["question"], answer=item["answer"]))
|
||||
elif "content" in item:
|
||||
data.preview.append(PreviewItem(content=item["content"], child_chunks=None, summary=None))
|
||||
|
||||
for item in preview.get("qa_preview", []):
|
||||
if "question" in item and "answer" in item:
|
||||
data.qa_preview.append(QaPreview(question=item["question"], answer=item["answer"]))
|
||||
return data
|
||||
|
||||
def index_and_clean(
|
||||
|
||||
@ -20,6 +20,7 @@ from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.utils import JSON_DICT_ADAPTER
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository
|
||||
from extensions.logstore.aliyun_logstore import AliyunLogStore
|
||||
@ -48,10 +49,10 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
|
||||
"""
|
||||
logger.debug("_dict_to_workflow_node_execution: data keys=%s", list(data.keys())[:5])
|
||||
# Parse JSON fields
|
||||
inputs = json.loads(data.get("inputs", "{}"))
|
||||
process_data = json.loads(data.get("process_data", "{}"))
|
||||
outputs = json.loads(data.get("outputs", "{}"))
|
||||
metadata = json.loads(data.get("execution_metadata", "{}"))
|
||||
inputs = JSON_DICT_ADAPTER.validate_json(data.get("inputs") or "{}")
|
||||
process_data = JSON_DICT_ADAPTER.validate_json(data.get("process_data") or "{}")
|
||||
outputs = JSON_DICT_ADAPTER.validate_json(data.get("outputs") or "{}")
|
||||
metadata = JSON_DICT_ADAPTER.validate_json(data.get("execution_metadata") or "{}")
|
||||
|
||||
# Convert metadata to domain enum keys
|
||||
domain_metadata = {}
|
||||
|
||||
@ -15,8 +15,12 @@ from datetime import datetime
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_metadata_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
|
||||
|
||||
|
||||
class FileStatus(StrEnum):
|
||||
"""File status enumeration"""
|
||||
@ -455,8 +459,8 @@ class FileLifecycleManager:
|
||||
try:
|
||||
if self._storage.exists(self._metadata_file):
|
||||
metadata_content = self._storage.load_once(self._metadata_file)
|
||||
result = json.loads(metadata_content.decode("utf-8"))
|
||||
return dict(result) if result else {}
|
||||
result = _metadata_adapter.validate_json(metadata_content)
|
||||
return result or {}
|
||||
else:
|
||||
return {}
|
||||
except Exception as e:
|
||||
|
||||
@ -1,13 +1,16 @@
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from google.cloud import storage as google_cloud_storage # type: ignore
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
|
||||
_service_account_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
|
||||
|
||||
|
||||
class GoogleCloudStorage(BaseStorage):
|
||||
"""Implementation for Google Cloud storage."""
|
||||
@ -21,7 +24,7 @@ class GoogleCloudStorage(BaseStorage):
|
||||
if service_account_json_str:
|
||||
service_account_json = base64.b64decode(service_account_json_str).decode("utf-8")
|
||||
# convert str to object
|
||||
service_account_obj = json.loads(service_account_json)
|
||||
service_account_obj = _service_account_adapter.validate_json(service_account_json)
|
||||
self.client = google_cloud_storage.Client.from_service_account_info(service_account_obj)
|
||||
else:
|
||||
self.client = google_cloud_storage.Client()
|
||||
|
||||
@ -132,7 +132,7 @@ class AudioService:
|
||||
uuid.UUID(message_id)
|
||||
except ValueError:
|
||||
return None
|
||||
message = db.session.query(Message).where(Message.id == message_id).first()
|
||||
message = db.session.get(Message, message_id)
|
||||
if message is None:
|
||||
return None
|
||||
if message.answer == "" and message.status in {MessageStatus.NORMAL, MessageStatus.PAUSED}:
|
||||
|
||||
@ -6,6 +6,7 @@ from typing import Literal
|
||||
|
||||
import httpx
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
|
||||
from typing_extensions import TypedDict
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
@ -158,10 +159,10 @@ class BillingService:
|
||||
def is_tenant_owner_or_admin(current_user: Account):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
join: TenantAccountJoin | None = (
|
||||
db.session.query(TenantAccountJoin)
|
||||
join: TenantAccountJoin | None = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not join:
|
||||
|
||||
@ -137,11 +137,11 @@ class ConversationService:
|
||||
@classmethod
|
||||
def auto_generate_name(cls, app_model: App, conversation: Conversation):
|
||||
# get conversation first message
|
||||
message = (
|
||||
db.session.query(Message)
|
||||
message = db.session.scalar(
|
||||
select(Message)
|
||||
.where(Message.app_id == app_model.id, Message.conversation_id == conversation.id)
|
||||
.order_by(Message.created_at.asc())
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not message:
|
||||
@ -160,8 +160,8 @@ class ConversationService:
|
||||
|
||||
@classmethod
|
||||
def get_conversation(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
conversation = db.session.scalar(
|
||||
select(Conversation)
|
||||
.where(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.app_id == app_model.id,
|
||||
@ -170,7 +170,7 @@ class ConversationService:
|
||||
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||
Conversation.is_deleted == False,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
@ -29,13 +29,13 @@ class CreditPoolService:
|
||||
@classmethod
|
||||
def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> TenantCreditPool | None:
|
||||
"""get tenant credit pool"""
|
||||
return (
|
||||
db.session.query(TenantCreditPool)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
pool_type=pool_type,
|
||||
return db.session.scalar(
|
||||
select(TenantCreditPool)
|
||||
.where(
|
||||
TenantCreditPool.tenant_id == tenant_id,
|
||||
TenantCreditPool.pool_type == pool_type,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -4,6 +4,7 @@ import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from redis import RedisError
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
@ -104,7 +105,9 @@ def sync_account_deletion(account_id: str, *, source: str) -> bool:
|
||||
return True
|
||||
|
||||
# Fetch all workspaces the account belongs to
|
||||
workspace_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).all()
|
||||
workspace_joins = db.session.scalars(
|
||||
select(TenantAccountJoin).where(TenantAccountJoin.account_id == account_id)
|
||||
).all()
|
||||
|
||||
# Queue sync task for each workspace
|
||||
success = True
|
||||
|
||||
@ -110,7 +110,7 @@ class PipelineGenerateService:
|
||||
Update document status to waiting
|
||||
:param document_id: document id
|
||||
"""
|
||||
document = db.session.query(Document).where(Document.id == document_id).first()
|
||||
document = db.session.get(Document, document_id)
|
||||
if document:
|
||||
document.indexing_status = IndexingStatus.WAITING
|
||||
db.session.add(document)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import yaml
|
||||
from sqlalchemy import select
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
@ -32,12 +33,11 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
:param language: language
|
||||
:return:
|
||||
"""
|
||||
pipeline_customized_templates = (
|
||||
db.session.query(PipelineCustomizedTemplate)
|
||||
pipeline_customized_templates = db.session.scalars(
|
||||
select(PipelineCustomizedTemplate)
|
||||
.where(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language)
|
||||
.order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
recommended_pipelines_results = []
|
||||
for pipeline_customized_template in pipeline_customized_templates:
|
||||
recommended_pipeline_result = {
|
||||
@ -59,9 +59,7 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
:param template_id: Template ID
|
||||
:return:
|
||||
"""
|
||||
pipeline_template = (
|
||||
db.session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
|
||||
)
|
||||
pipeline_template = db.session.get(PipelineCustomizedTemplate, template_id)
|
||||
if not pipeline_template:
|
||||
return None
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import yaml
|
||||
from sqlalchemy import select
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import PipelineBuiltInTemplate
|
||||
@ -30,8 +31,10 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
:return:
|
||||
"""
|
||||
|
||||
pipeline_built_in_templates: list[PipelineBuiltInTemplate] = (
|
||||
db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.language == language).all()
|
||||
pipeline_built_in_templates = list(
|
||||
db.session.scalars(
|
||||
select(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.language == language)
|
||||
).all()
|
||||
)
|
||||
|
||||
recommended_pipelines_results = []
|
||||
@ -58,9 +61,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
:return:
|
||||
"""
|
||||
# is in public recommended list
|
||||
pipeline_template = (
|
||||
db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.id == template_id).first()
|
||||
)
|
||||
pipeline_template = db.session.get(PipelineBuiltInTemplate, template_id)
|
||||
|
||||
if not pipeline_template:
|
||||
return None
|
||||
|
||||
@ -77,17 +77,15 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
:return:
|
||||
"""
|
||||
# is in public recommended list
|
||||
recommended_app = (
|
||||
db.session.query(RecommendedApp)
|
||||
.where(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id)
|
||||
.first()
|
||||
recommended_app = db.session.scalar(
|
||||
select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id).limit(1)
|
||||
)
|
||||
|
||||
if not recommended_app:
|
||||
return None
|
||||
|
||||
# get app detail
|
||||
app_model = db.session.query(App).where(App.id == app_id).first()
|
||||
app_model = db.session.get(App, app_id)
|
||||
if not app_model or not app_model.is_public:
|
||||
return None
|
||||
|
||||
|
||||
@ -64,15 +64,15 @@ class WebConversationService:
|
||||
def pin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
|
||||
if not user:
|
||||
return
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
pinned_conversation = db.session.scalar(
|
||||
select(PinnedConversation)
|
||||
.where(
|
||||
PinnedConversation.app_id == app_model.id,
|
||||
PinnedConversation.conversation_id == conversation_id,
|
||||
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
||||
PinnedConversation.created_by == user.id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if pinned_conversation:
|
||||
@ -96,15 +96,15 @@ class WebConversationService:
|
||||
def unpin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
|
||||
if not user:
|
||||
return
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
pinned_conversation = db.session.scalar(
|
||||
select(PinnedConversation)
|
||||
.where(
|
||||
PinnedConversation.app_id == app_model.id,
|
||||
PinnedConversation.conversation_id == conversation_id,
|
||||
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
||||
PinnedConversation.created_by == user.id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not pinned_conversation:
|
||||
|
||||
@ -3,6 +3,7 @@ import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
@ -92,10 +93,10 @@ class WebAppAuthService:
|
||||
|
||||
@classmethod
|
||||
def create_end_user(cls, app_code, email) -> EndUser:
|
||||
site = db.session.query(Site).where(Site.code == app_code).first()
|
||||
site = db.session.scalar(select(Site).where(Site.code == app_code).limit(1))
|
||||
if not site:
|
||||
raise NotFound("Site not found.")
|
||||
app_model = db.session.query(App).where(App.id == site.app_id).first()
|
||||
app_model = db.session.get(App, site.app_id)
|
||||
if not app_model:
|
||||
raise NotFound("App not found.")
|
||||
end_user = EndUser(
|
||||
|
||||
@ -6,6 +6,7 @@ from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from graphon.nodes import BuiltinNodeTypes
|
||||
from graphon.variables.input_entities import VariableEntity
|
||||
from sqlalchemy import select
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
@ -648,10 +649,10 @@ class WorkflowConverter:
|
||||
:param api_based_extension_id: api based extension id
|
||||
:return:
|
||||
"""
|
||||
api_based_extension = (
|
||||
db.session.query(APIBasedExtension)
|
||||
api_based_extension = db.session.scalar(
|
||||
select(APIBasedExtension)
|
||||
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not api_based_extension:
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from enums.cloud_plan import CloudPlan
|
||||
@ -24,10 +25,10 @@ class WorkspaceService:
|
||||
}
|
||||
|
||||
# Get role of user
|
||||
tenant_account_join = (
|
||||
db.session.query(TenantAccountJoin)
|
||||
tenant_account_join = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
assert tenant_account_join is not None, "TenantAccountJoin not found"
|
||||
tenant_info["role"] = tenant_account_join.role
|
||||
|
||||
@ -381,13 +381,22 @@ def test_init_client_constructs_configuration_and_client(baidu_module, monkeypat
|
||||
monkeypatch.setattr(baidu_module, "MochowClient", client_cls)
|
||||
|
||||
vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
|
||||
config = SimpleNamespace(account="account", api_key="key", endpoint="https://endpoint")
|
||||
config = SimpleNamespace(
|
||||
account="account",
|
||||
api_key="key",
|
||||
endpoint="https://endpoint",
|
||||
connection_timeout_in_mills=12_345,
|
||||
)
|
||||
|
||||
client = vector._init_client(config)
|
||||
|
||||
assert client == "client"
|
||||
credentials.assert_called_once_with("account", "key")
|
||||
configuration.assert_called_once_with(credentials="credentials", endpoint="https://endpoint")
|
||||
configuration.assert_called_once_with(
|
||||
credentials="credentials",
|
||||
endpoint="https://endpoint",
|
||||
connection_timeout_in_mills=12_345,
|
||||
)
|
||||
client_cls.assert_called_once_with("configuration")
|
||||
|
||||
|
||||
|
||||
45
api/tests/unit_tests/core/rag/datasource/vdb/test_field.py
Normal file
45
api/tests/unit_tests/core/rag/datasource/vdb/test_field.py
Normal file
@ -0,0 +1,45 @@
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.field import parse_metadata_json
|
||||
|
||||
|
||||
class TestParseMetadataJson:
|
||||
def test_none_returns_empty_dict(self):
|
||||
assert parse_metadata_json(None) == {}
|
||||
|
||||
def test_empty_string_returns_empty_dict(self):
|
||||
assert parse_metadata_json("") == {}
|
||||
|
||||
def test_valid_json_string(self):
|
||||
result = parse_metadata_json('{"doc_id": "abc", "score": 0.9}')
|
||||
assert result == {"doc_id": "abc", "score": 0.9}
|
||||
|
||||
def test_dict_passthrough(self):
|
||||
original = {"doc_id": "abc", "document_id": "123"}
|
||||
result = parse_metadata_json(original)
|
||||
assert result == original
|
||||
|
||||
def test_empty_json_object(self):
|
||||
assert parse_metadata_json("{}") == {}
|
||||
|
||||
def test_invalid_json_raises_value_error(self):
|
||||
with pytest.raises(ValueError):
|
||||
parse_metadata_json("{invalid json")
|
||||
|
||||
def test_nested_metadata(self):
|
||||
result = parse_metadata_json('{"doc_id": "1", "extra": {"nested": true}}')
|
||||
assert result["extra"]["nested"] is True
|
||||
|
||||
def test_non_str_non_dict_returns_empty_dict(self):
|
||||
assert parse_metadata_json(123) == {}
|
||||
assert parse_metadata_json([1, 2]) == {}
|
||||
|
||||
def test_bytes_input(self):
|
||||
result = parse_metadata_json(b'{"key": "value"}')
|
||||
assert result == {"key": "value"}
|
||||
|
||||
def test_empty_bytes_returns_empty_dict(self):
|
||||
assert parse_metadata_json(b"") == {}
|
||||
|
||||
def test_empty_bytearray_returns_empty_dict(self):
|
||||
assert parse_metadata_json(bytearray(b"")) == {}
|
||||
@ -0,0 +1,15 @@
|
||||
from core.rag.index_processor.index_processor import IndexProcessor
|
||||
|
||||
|
||||
class TestIndexProcessor:
|
||||
def test_format_preview_supports_qa_preview_shape(self) -> None:
|
||||
preview = IndexProcessor().format_preview(
|
||||
"qa_model",
|
||||
{"qa_chunks": [{"question": "Q1", "answer": "A1"}]},
|
||||
)
|
||||
|
||||
assert preview.chunk_structure == "qa_model"
|
||||
assert preview.total_segments == 1
|
||||
assert len(preview.qa_preview) == 1
|
||||
assert preview.qa_preview[0].question == "Q1"
|
||||
assert preview.qa_preview[0].answer == "A1"
|
||||
@ -421,11 +421,8 @@ class TestAudioServiceTTS:
|
||||
answer="Message answer text",
|
||||
)
|
||||
|
||||
# Mock database query
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = message
|
||||
# Mock database lookup
|
||||
mock_db_session.get.return_value = message
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
@ -568,11 +565,8 @@ class TestAudioServiceTTS:
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
|
||||
# Mock database query returning None
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
# Mock database lookup returning None
|
||||
mock_db_session.get.return_value = None
|
||||
|
||||
# Act
|
||||
result = AudioService.transcript_tts(
|
||||
@ -594,11 +588,8 @@ class TestAudioServiceTTS:
|
||||
status=MessageStatus.NORMAL,
|
||||
)
|
||||
|
||||
# Mock database query
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = message
|
||||
# Mock database lookup
|
||||
mock_db_session.get.return_value = message
|
||||
|
||||
# Act
|
||||
result = AudioService.transcript_tts(
|
||||
|
||||
@ -865,16 +865,11 @@ class TestBillingServiceAccountManagement:
|
||||
mock_join = MagicMock(spec=TenantAccountJoin)
|
||||
mock_join.role = TenantAccountRole.OWNER
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.where.return_value.first.return_value = mock_join
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_db_session.scalar.return_value = mock_join
|
||||
|
||||
# Act - should not raise exception
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
|
||||
# Assert
|
||||
mock_db_session.query.assert_called_once()
|
||||
|
||||
def test_is_tenant_owner_or_admin_admin(self, mock_db_session):
|
||||
"""Test tenant owner/admin check for admin role."""
|
||||
# Arrange
|
||||
@ -885,16 +880,11 @@ class TestBillingServiceAccountManagement:
|
||||
mock_join = MagicMock(spec=TenantAccountJoin)
|
||||
mock_join.role = TenantAccountRole.ADMIN
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.where.return_value.first.return_value = mock_join
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_db_session.scalar.return_value = mock_join
|
||||
|
||||
# Act - should not raise exception
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
|
||||
# Assert
|
||||
mock_db_session.query.assert_called_once()
|
||||
|
||||
def test_is_tenant_owner_or_admin_normal_user_raises_error(self, mock_db_session):
|
||||
"""Test tenant owner/admin check raises error for normal user."""
|
||||
# Arrange
|
||||
@ -905,9 +895,7 @@ class TestBillingServiceAccountManagement:
|
||||
mock_join = MagicMock(spec=TenantAccountJoin)
|
||||
mock_join.role = TenantAccountRole.NORMAL
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.where.return_value.first.return_value = mock_join
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_db_session.scalar.return_value = mock_join
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
@ -921,9 +909,7 @@ class TestBillingServiceAccountManagement:
|
||||
current_user.id = "account-123"
|
||||
current_user.current_tenant_id = "tenant-456"
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.where.return_value.first.return_value = None
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
@ -1135,9 +1121,7 @@ class TestBillingServiceEdgeCases:
|
||||
mock_join.role = TenantAccountRole.EDITOR # Editor is not privileged
|
||||
|
||||
with patch("services.billing_service.db.session") as mock_session:
|
||||
mock_query = MagicMock()
|
||||
mock_query.where.return_value.first.return_value = mock_join
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_session.scalar.return_value = mock_join
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
@ -1155,9 +1139,7 @@ class TestBillingServiceEdgeCases:
|
||||
mock_join.role = TenantAccountRole.DATASET_OPERATOR # Dataset operator is not privileged
|
||||
|
||||
with patch("services.billing_service.db.session") as mock_session:
|
||||
mock_query = MagicMock()
|
||||
mock_query.where.return_value.first.return_value = mock_join
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_session.scalar.return_value = mock_join
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
|
||||
@ -355,15 +355,13 @@ class TestConversationServiceGetConversation:
|
||||
from_account_id=user.id, from_source=ConversationFromSource.CONSOLE
|
||||
)
|
||||
|
||||
mock_query = mock_db_session.query.return_value
|
||||
mock_query.where.return_value.first.return_value = conversation
|
||||
mock_db_session.scalar.return_value = conversation
|
||||
|
||||
# Act
|
||||
result = ConversationService.get_conversation(app_model, "conv-123", user)
|
||||
|
||||
# Assert
|
||||
assert result == conversation
|
||||
mock_db_session.query.assert_called_once_with(Conversation)
|
||||
|
||||
@patch("services.conversation_service.db.session")
|
||||
def test_get_conversation_success_with_end_user(self, mock_db_session):
|
||||
@ -379,8 +377,7 @@ class TestConversationServiceGetConversation:
|
||||
from_end_user_id=user.id, from_source=ConversationFromSource.API
|
||||
)
|
||||
|
||||
mock_query = mock_db_session.query.return_value
|
||||
mock_query.where.return_value.first.return_value = conversation
|
||||
mock_db_session.scalar.return_value = conversation
|
||||
|
||||
# Act
|
||||
result = ConversationService.get_conversation(app_model, "conv-123", user)
|
||||
@ -399,8 +396,7 @@ class TestConversationServiceGetConversation:
|
||||
app_model = ConversationServiceTestDataFactory.create_app_mock()
|
||||
user = ConversationServiceTestDataFactory.create_account_mock()
|
||||
|
||||
mock_query = mock_db_session.query.return_value
|
||||
mock_query.where.return_value.first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ConversationNotExistsError):
|
||||
@ -489,8 +485,7 @@ class TestConversationServiceAutoGenerateName:
|
||||
)
|
||||
|
||||
# Mock database query to return message
|
||||
mock_query = mock_db_session.query.return_value
|
||||
mock_query.where.return_value.order_by.return_value.first.return_value = message
|
||||
mock_db_session.scalar.return_value = message
|
||||
|
||||
# Mock LLM generator
|
||||
mock_llm_generator.generate_conversation_name.return_value = "Generated Name"
|
||||
@ -518,8 +513,7 @@ class TestConversationServiceAutoGenerateName:
|
||||
conversation = ConversationServiceTestDataFactory.create_conversation_mock()
|
||||
|
||||
# Mock database query to return None
|
||||
mock_query = mock_db_session.query.return_value
|
||||
mock_query.where.return_value.order_by.return_value.first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(MessageNotExistsError):
|
||||
@ -541,8 +535,7 @@ class TestConversationServiceAutoGenerateName:
|
||||
)
|
||||
|
||||
# Mock database query to return message
|
||||
mock_query = mock_db_session.query.return_value
|
||||
mock_query.where.return_value.order_by.return_value.first.return_value = message
|
||||
mock_db_session.scalar.return_value = message
|
||||
|
||||
# Mock LLM generator to raise exception
|
||||
mock_llm_generator.generate_conversation_name.side_effect = Exception("LLM Error")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user