Merge branch 'main' into jzh

This commit is contained in:
JzoNg 2026-04-10 13:49:26 +08:00
commit 4680535ecd
45 changed files with 248 additions and 262 deletions

View File

@ -227,10 +227,11 @@ class ExternalApiUseCheckApi(Resource):
@login_required
@account_initialization_required
def get(self, external_knowledge_api_id):
_, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id)
external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
external_knowledge_api_id
external_knowledge_api_id, current_tenant_id
)
return {"is_using": external_knowledge_api_is_using, "count": count}, 200

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
from graphon.file import File, FileUploadConfig
from graphon.model_runtime.entities.model_entities import AIModelEntity
@ -131,7 +131,7 @@ class AppGenerateEntity(BaseModel):
extras: dict[str, Any] = Field(default_factory=dict)
# tracing instance
trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False)
trace_manager: "TraceQueueManager | None" = Field(default=None, exclude=True, repr=False)
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):

View File

@ -1,4 +1,4 @@
from typing import Any, Literal, Optional, TypedDict
from typing import Any, Literal, TypedDict
from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel, Field, field_validator
@ -17,7 +17,7 @@ class DatasourceApiEntity(BaseModel):
output_schema: dict | None = None
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow"] | None
class DatasourceProviderApiEntityDict(TypedDict):

View File

@ -71,8 +71,8 @@ class DatasourceFileMessageTransformer:
if not isinstance(message.message, DatasourceMessage.BlobMessage):
raise ValueError("unexpected message type")
# FIXME: should do a type check here.
assert isinstance(message.message.blob, bytes)
if not isinstance(message.message.blob, bytes):
raise TypeError(f"Expected blob to be bytes, got {type(message.message.blob).__name__}")
tool_file_manager = ToolFileManager()
blob_tool_file: ToolFile | None = tool_file_manager.create_file_by_raw(
user_id=user_id,

View File

@ -122,7 +122,7 @@ class MCPClientWithAuthRetry(MCPClient):
logger.exception("Authentication retry failed")
raise MCPAuthError(f"Authentication retry failed: {e}") from e
def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any:
def _execute_with_retry[**P, R](self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
"""
Execute a function with authentication retry logic.

View File

@ -1,6 +1,6 @@
from dataclasses import dataclass
from enum import StrEnum
from typing import Any, TypeVar
from typing import Any
from pydantic import BaseModel
@ -9,12 +9,9 @@ from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAut
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
LifespanContextT = TypeVar("LifespanContextT")
@dataclass
class RequestContext[SessionT: BaseSession[Any, Any, Any, Any, Any], LifespanContextT]:
class RequestContext[SessionT: BaseSession, LifespanContextT]:
request_id: RequestId
meta: RequestParams.Meta | None
session: SessionT

View File

@ -6,7 +6,7 @@ from graphon.model_runtime.callbacks.base_callback import Callback
from graphon.model_runtime.entities.llm_entities import LLMResult
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType
from graphon.model_runtime.entities.rerank_entities import RerankResult
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@ -172,10 +172,10 @@ class ModelInstance:
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
prompt_messages=prompt_messages,
prompt_messages=list(prompt_messages),
model_parameters=model_parameters,
tools=tools,
stop=stop,
tools=list(tools) if tools else None,
stop=list(stop) if stop else None,
stream=stream,
callbacks=callbacks,
),
@ -193,15 +193,12 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
return cast(
int,
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
prompt_messages=prompt_messages,
tools=tools,
),
return self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
prompt_messages=list(prompt_messages),
tools=list(tools) if tools else None,
)
def invoke_text_embedding(
@ -216,15 +213,12 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return cast(
EmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
texts=texts,
input_type=input_type,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
texts=texts,
input_type=input_type,
)
def invoke_multimodal_embedding(
@ -241,15 +235,12 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return cast(
EmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
input_type=input_type,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
input_type=input_type,
)
def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
@ -261,14 +252,11 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return cast(
list[int],
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
texts=texts,
),
return self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
texts=texts,
)
def invoke_rerank(
@ -289,23 +277,20 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return cast(
RerankResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)
def invoke_multimodal_rerank(
self,
query: dict,
docs: list[dict],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None = None,
top_n: int | None = None,
) -> RerankResult:
@ -320,17 +305,14 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return cast(
RerankResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke_multimodal_rerank,
model=self.model_name,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke_multimodal_rerank,
model=self.model_name,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)
def invoke_moderation(self, text: str) -> bool:
@ -342,14 +324,11 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")
return cast(
bool,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
text=text,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
text=text,
)
def invoke_speech2text(self, file: IO[bytes]) -> str:
@ -361,14 +340,11 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")
return cast(
str,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
file=file,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
file=file,
)
def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]:
@ -381,18 +357,15 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
return cast(
Iterable[bytes],
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
content_text=content_text,
voice=voice,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
content_text=content_text,
voice=voice,
)
def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
"""
Round-robin invoke
:param function: function to invoke
@ -430,9 +403,8 @@ class ModelInstance:
continue
try:
if "credentials" in kwargs:
del kwargs["credentials"]
return function(*args, **kwargs, credentials=lb_config.credentials)
kwargs["credentials"] = lb_config.credentials
return function(*args, **kwargs)
except InvokeRateLimitError as e:
# expire in 60 seconds
self.load_balancing_manager.cooldown(lb_config, expire=60)

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel, model_validator
from sqlalchemy import Column, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.dialects.postgresql import JSON, TEXT
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@ -79,7 +79,7 @@ class RelytVector(BaseVector):
if redis_client.get(collection_exist_cache_key):
return
index_name = f"{self._collection_name}_embedding_index"
with Session(self.client) as session:
with sessionmaker(bind=self.client).begin() as session:
drop_statement = sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}"; """)
session.execute(drop_statement)
create_statement = sql_text(f"""
@ -104,7 +104,6 @@ class RelytVector(BaseVector):
$$);
""")
session.execute(index_statement)
session.commit()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@ -208,9 +207,8 @@ class RelytVector(BaseVector):
self.delete_by_uuids(ids)
def delete(self):
with Session(self.client) as session:
with sessionmaker(bind=self.client).begin() as session:
session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";"""))
session.commit()
def text_exists(self, id: str) -> bool:
with Session(self.client) as session:

View File

@ -6,7 +6,7 @@ import sqlalchemy
from pydantic import BaseModel, model_validator
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.orm import Session, declarative_base
from sqlalchemy.orm import Session, declarative_base, sessionmaker
from configs import dify_config
from core.rag.datasource.vdb.field import Field, parse_metadata_json
@ -97,8 +97,7 @@ class TiDBVector(BaseVector):
if redis_client.get(collection_exist_cache_key):
return
tidb_dist_func = self._get_distance_func()
with Session(self._engine) as session:
session.begin()
with sessionmaker(bind=self._engine).begin() as session:
create_statement = sql_text(f"""
CREATE TABLE IF NOT EXISTS {self._collection_name} (
id CHAR(36) PRIMARY KEY,
@ -115,7 +114,6 @@ class TiDBVector(BaseVector):
);
""")
session.execute(create_statement)
session.commit()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@ -238,9 +236,8 @@ class TiDBVector(BaseVector):
return []
def delete(self):
with Session(self._engine) as session:
with sessionmaker(bind=self._engine).begin() as session:
session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
session.commit()
def _get_distance_func(self) -> str:
match self._distance_func:

View File

@ -3,8 +3,7 @@
import logging
import re
import uuid
from collections.abc import Mapping
from typing import Any, cast
from typing import Any, TypedDict, cast
logger = logging.getLogger(__name__)
@ -55,6 +54,12 @@ from services.summary_index_service import SummaryIndexService
_file_access_controller = DatabaseFileAccessController()
class ParagraphFormatPreviewDict(TypedDict):
chunk_structure: str
preview: list[dict[str, Any]]
total_segments: int
class ParagraphIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(
@ -266,16 +271,17 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
keyword = Keyword(dataset)
keyword.add_texts(documents)
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
def format_preview(self, chunks: Any) -> ParagraphFormatPreviewDict:
if isinstance(chunks, list):
preview = []
for content in chunks:
preview.append({"content": content})
return {
result: ParagraphFormatPreviewDict = {
"chunk_structure": IndexStructureType.PARAGRAPH_INDEX,
"preview": preview,
"total_segments": len(chunks),
}
return result
else:
raise ValueError("Chunks is not a list")

View File

@ -3,8 +3,7 @@
import json
import logging
import uuid
from collections.abc import Mapping
from typing import Any
from typing import Any, TypedDict
from sqlalchemy import delete, select
@ -36,6 +35,13 @@ from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)
class ParentChildFormatPreviewDict(TypedDict):
chunk_structure: str
parent_mode: str
preview: list[dict[str, Any]]
total_segments: int
class ParentChildIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(
@ -351,17 +357,18 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
if all_multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(all_multimodal_documents)
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
def format_preview(self, chunks: Any) -> ParentChildFormatPreviewDict:
parent_childs = ParentChildStructureChunk.model_validate(chunks)
preview = []
for parent_child in parent_childs.parent_child_chunks:
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
return {
result: ParentChildFormatPreviewDict = {
"chunk_structure": IndexStructureType.PARENT_CHILD_INDEX,
"parent_mode": parent_childs.parent_mode,
"preview": preview,
"total_segments": len(parent_childs.parent_child_chunks),
}
return result
def generate_summary_preview(
self,

View File

@ -4,8 +4,7 @@ import logging
import re
import threading
import uuid
from collections.abc import Mapping
from typing import Any
from typing import Any, TypedDict
import pandas as pd
from flask import Flask, current_app
@ -36,6 +35,12 @@ from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)
class QAFormatPreviewDict(TypedDict):
chunk_structure: str
qa_preview: list[dict[str, Any]]
total_segments: int
class QAIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(
@ -230,16 +235,17 @@ class QAIndexProcessor(BaseIndexProcessor):
else:
raise ValueError("Indexing technique must be high quality.")
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
def format_preview(self, chunks: Any) -> QAFormatPreviewDict:
qa_chunks = QAStructureChunk.model_validate(chunks)
preview = []
for qa_chunk in qa_chunks.qa_chunks:
preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})
return {
result: QAFormatPreviewDict = {
"chunk_structure": IndexStructureType.QA_INDEX,
"qa_preview": preview,
"total_segments": len(qa_chunks.qa_chunks),
}
return result
def generate_summary_preview(
self,

View File

@ -1,7 +1,7 @@
import base64
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.entities.rerank_entities import RerankResult
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from core.model_manager import ModelInstance, ModelManager
from core.rag.index_processor.constant.doc_type import DocType
@ -123,7 +123,7 @@ class RerankModelRunner(BaseRerankRunner):
:param query_type: query type
:return: rerank result
"""
docs = []
docs: list[MultimodalRerankInput] = []
doc_ids = set()
unique_documents = []
for document in documents:
@ -138,26 +138,28 @@ class RerankModelRunner(BaseRerankRunner):
if upload_file:
blob = storage.load_once(upload_file.key)
document_file_base64 = base64.b64encode(blob).decode()
document_file_dict = {
"content": document_file_base64,
"content_type": document.metadata["doc_type"],
}
docs.append(document_file_dict)
docs.append(
MultimodalRerankInput(
content=document_file_base64,
content_type=document.metadata["doc_type"],
)
)
else:
document_text_dict = {
"content": document.page_content,
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
}
docs.append(document_text_dict)
docs.append(
MultimodalRerankInput(
content=document.page_content,
content_type=document.metadata.get("doc_type") or DocType.TEXT,
)
)
doc_ids.add(document.metadata["doc_id"])
unique_documents.append(document)
elif document.provider == "external":
if document not in unique_documents:
docs.append(
{
"content": document.page_content,
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
}
MultimodalRerankInput(
content=document.page_content,
content_type=document.metadata.get("doc_type") or DocType.TEXT,
)
)
unique_documents.append(document)
@ -171,12 +173,12 @@ class RerankModelRunner(BaseRerankRunner):
if upload_file:
blob = storage.load_once(upload_file.key)
file_query = base64.b64encode(blob).decode()
file_query_dict = {
"content": file_query,
"content_type": DocType.IMAGE,
}
file_query_input = MultimodalRerankInput(
content=file_query,
content_type=DocType.IMAGE,
)
rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n
query=file_query_input, docs=docs, score_threshold=score_threshold, top_n=top_n
)
return rerank_result, unique_documents
else:

View File

@ -118,7 +118,8 @@ class ToolFileMessageTransformer:
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
raise ValueError("unexpected message type")
assert isinstance(message.message.blob, bytes)
if not isinstance(message.message.blob, bytes):
raise TypeError(f"Expected blob to be bytes, got {type(message.message.blob).__name__}")
tool_file_manager = ToolFileManager()
tool_file = tool_file_manager.create_file_by_raw(
user_id=user_id,

View File

@ -10,7 +10,7 @@ import uuid
from collections.abc import Callable, Generator, Mapping
from datetime import datetime
from hashlib import sha256
from typing import TYPE_CHECKING, Annotated, Any, Optional, Protocol, Union, cast
from typing import TYPE_CHECKING, Annotated, Any, Protocol, cast
from uuid import UUID
from zoneinfo import available_timezones
@ -81,7 +81,7 @@ def escape_like_pattern(pattern: str) -> str:
return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None:
def extract_tenant_id(user: "Account | EndUser") -> str | None:
"""
Extract tenant_id from Account or EndUser object.
@ -164,7 +164,10 @@ def email(email):
EmailStr = Annotated[str, AfterValidator(email)]
def uuid_value(value: Any) -> str:
def uuid_value(value: str | UUID) -> str:
if isinstance(value, UUID):
return str(value)
if value == "":
return str(value)
@ -405,7 +408,7 @@ class TokenManager:
def generate_token(
cls,
token_type: str,
account: Optional["Account"] = None,
account: "Account | None" = None,
email: str | None = None,
additional_data: dict | None = None,
) -> str:
@ -465,9 +468,7 @@ class TokenManager:
return current_token
@classmethod
def _set_current_token_for_account(
cls, account_id: str, token: str, token_type: str, expiry_minutes: Union[int, float]
):
def _set_current_token_for_account(cls, account_id: str, token: str, token_type: str, expiry_minutes: int | float):
key = cls._get_account_token_key(account_id, token_type)
expiry_seconds = int(expiry_minutes * 60)
redis_client.setex(key, expiry_seconds, token)

View File

@ -913,11 +913,7 @@ class TrialApp(TypeBase):
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
insert_default=func.current_timestamp(),
server_default=func.current_timestamp(),
init=False,
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
trial_limit: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=3)
@ -941,11 +937,7 @@ class AccountTrialAppRecord(TypeBase):
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
insert_default=func.current_timestamp(),
server_default=func.current_timestamp(),
init=False,
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
@property

View File

@ -148,18 +148,23 @@ class ExternalDatasetService:
db.session.commit()
@staticmethod
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
def external_knowledge_api_use_check(external_knowledge_api_id: str, tenant_id: str) -> tuple[bool, int]:
"""
Return usage for an external knowledge API within a single tenant.
The caller already scopes access by tenant, so this query must do the
same; otherwise the endpoint becomes a cross-tenant UUID oracle.
"""
count = (
db.session.scalar(
select(func.count(ExternalKnowledgeBindings.id)).where(
ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id
ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id,
ExternalKnowledgeBindings.tenant_id == tenant_id,
)
)
or 0
)
if count > 0:
return True, count
return False, 0
return count > 0, count
@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:

View File

@ -1,3 +1,4 @@
from importlib import import_module
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
@ -11,6 +12,7 @@ from controllers.console.datasets.external import (
BedrockRetrievalApi,
ExternalApiTemplateApi,
ExternalApiTemplateListApi,
ExternalApiUseCheckApi,
ExternalDatasetCreateApi,
ExternalKnowledgeHitTestingApi,
)
@ -19,6 +21,8 @@ from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService
external_controller = import_module("controllers.console.datasets.external")
def unwrap(func):
while hasattr(func, "__wrapped__"):
@ -44,10 +48,11 @@ def current_user():
@pytest.fixture(autouse=True)
def mock_auth(mocker, current_user):
mocker.patch(
"controllers.console.datasets.external.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
def mock_auth(monkeypatch, current_user):
monkeypatch.setattr(
external_controller,
"current_account_with_tenant",
lambda: (current_user, "tenant-1"),
)
@ -136,6 +141,26 @@ class TestExternalApiTemplateApi:
method(api, "api-id")
class TestExternalApiUseCheckApi:
def test_get_scopes_usage_check_to_current_tenant(self, app):
api = ExternalApiUseCheckApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch.object(
ExternalDatasetService,
"external_knowledge_api_use_check",
return_value=(True, 2),
) as mock_use_check,
):
response, status = method(api, "api-id")
assert status == 200
assert response == {"is_using": True, "count": 2}
mock_use_check.assert_called_once_with("api-id", "tenant-1")
class TestExternalDatasetCreateApi:
def test_create_success(self, app):
api = ExternalDatasetCreateApi()

View File

@ -4,9 +4,7 @@ from unittest.mock import Mock
from core.mcp.entities import (
SUPPORTED_PROTOCOL_VERSIONS,
LifespanContextT,
RequestContext,
SessionT,
)
from core.mcp.session.base_session import BaseSession
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestParams
@ -198,42 +196,3 @@ class TestRequestContext:
assert "RequestContext" in repr_str
assert "test-123" in repr_str
assert "MockSession" in repr_str
class TestTypeVariables:
"""Test type variables defined in the module."""
def test_session_type_var(self):
"""Test SessionT type variable."""
# Create a custom session class
class CustomSession(BaseSession):
pass
# Use in generic context
def process_session(session: SessionT) -> SessionT:
return session
mock_session = Mock(spec=CustomSession)
result = process_session(mock_session)
assert result == mock_session
def test_lifespan_context_type_var(self):
"""Test LifespanContextT type variable."""
# Use in generic context
def process_lifespan(context: LifespanContextT) -> LifespanContextT:
return context
# Test with different types
str_context = "string-context"
assert process_lifespan(str_context) == str_context
dict_context = {"key": "value"}
assert process_lifespan(dict_context) == dict_context
class CustomContext:
pass
custom_context = CustomContext()
assert process_lifespan(custom_context) == custom_context

View File

@ -39,6 +39,25 @@ class _FakeSession:
return None
class _FakeBeginContext:
def __init__(self, session):
self._session = session
def __enter__(self):
return self._session
def __exit__(self, exc_type, exc, tb):
return None
def _patch_both(monkeypatch, module, session):
"""Patch both Session and sessionmaker on the module."""
monkeypatch.setattr(module, "Session", lambda _client: session)
monkeypatch.setattr(
module, "sessionmaker", lambda **kwargs: MagicMock(begin=MagicMock(return_value=_FakeBeginContext(session)))
)
@pytest.fixture
def relyt_module(monkeypatch):
for name, module in _build_fake_relyt_modules().items():
@ -108,13 +127,13 @@ def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch):
monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=1))
session = _FakeSession()
monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
_patch_both(monkeypatch, relyt_module, session)
vector.create_collection(3)
session.execute.assert_not_called()
monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=None))
session = _FakeSession()
monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
_patch_both(monkeypatch, relyt_module, session)
vector.create_collection(3)
executed_sql = [str(call.args[0]) for call in session.execute.call_args_list]
assert any("DROP TABLE IF EXISTS" in sql for sql in executed_sql)
@ -265,15 +284,15 @@ def test_search_by_vector_filters_by_score_and_ids(relyt_module):
# 8. delete commits session
def test_delete_commits_session(relyt_module, monkeypatch):
def test_delete_drops_table(relyt_module, monkeypatch):
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
vector._collection_name = "collection_1"
vector.client = MagicMock()
vector.embedding_dimension = 3
session = _FakeSession()
monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
_patch_both(monkeypatch, relyt_module, session)
vector.delete()
session.commit.assert_called_once()
session.execute.assert_called_once()
def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch):

View File

@ -137,14 +137,15 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke
session = MagicMock()
class _SessionCtx:
class _BeginCtx:
def __enter__(self):
return session
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx()))
monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm)
vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
vector._collection_name = "collection_1"
@ -153,11 +154,9 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke
vector._create_collection(3)
session.begin.assert_called_once()
sql = str(session.execute.call_args.args[0])
assert "VECTOR<FLOAT>(3)" in sql
assert "VEC_L2_DISTANCE" in sql
session.commit.assert_called_once()
tidb_module.redis_client.set.assert_called_once()
@ -396,23 +395,22 @@ def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch):
def test_delete_drops_table(tidb_module, monkeypatch):
session = MagicMock()
session.execute.return_value = None
session.commit = MagicMock()
class _SessionCtx:
class _BeginCtx:
def __enter__(self):
return session
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx()))
monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm)
vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
vector._collection_name = "collection_1"
vector._engine = MagicMock()
vector.delete()
drop_sql = str(session.execute.call_args.args[0])
assert "DROP TABLE IF EXISTS collection_1" in drop_sql
session.commit.assert_called_once()
def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch):

View File

@ -396,10 +396,11 @@ class TestExternalDatasetServiceUsageAndBindings:
mock_db_session.scalar.return_value = 3
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1")
assert in_use is True
assert count == 3
assert "tenant_id" in str(mock_db_session.scalar.call_args.args[0])
def test_external_knowledge_api_use_check_not_in_use(self, mock_db_session: MagicMock):
"""
@ -408,7 +409,7 @@ class TestExternalDatasetServiceUsageAndBindings:
mock_db_session.scalar.return_value = 0
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1")
assert in_use is False
assert count == 0

View File

@ -974,26 +974,29 @@ class TestExternalDatasetServiceAPIUseCheck:
"""Test API use check when API has one binding."""
# Arrange
api_id = "api-123"
tenant_id = "tenant-123"
mock_db.session.scalar.return_value = 1
# Act
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id)
# Assert
assert in_use is True
assert count == 1
assert "tenant_id" in str(mock_db.session.scalar.call_args.args[0])
@patch("services.external_knowledge_service.db")
def test_external_knowledge_api_use_check_in_use_multiple(self, mock_db, factory):
"""Test API use check with multiple bindings."""
# Arrange
api_id = "api-123"
tenant_id = "tenant-123"
mock_db.session.scalar.return_value = 10
# Act
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id)
# Assert
assert in_use is True
@ -1004,11 +1007,12 @@ class TestExternalDatasetServiceAPIUseCheck:
"""Test API use check when API is not in use."""
# Arrange
api_id = "api-123"
tenant_id = "tenant-123"
mock_db.session.scalar.return_value = 0
# Act
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id)
# Assert
assert in_use is False

View File

@ -2,7 +2,7 @@
import type { Area } from 'react-easy-crop'
import type { OnImageInput } from '@/app/components/base/app-icon-picker/ImageInput'
import type { AvatarProps } from '@/app/components/base/avatar'
import type { AvatarProps } from '@/app/components/base/ui/avatar'
import type { ImageFile } from '@/types/app'
import { RiDeleteBin5Line, RiPencilLine } from '@remixicon/react'
import * as React from 'react'
@ -10,10 +10,10 @@ import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import ImageInput from '@/app/components/base/app-icon-picker/ImageInput'
import getCroppedImg from '@/app/components/base/app-icon-picker/utils'
import { Avatar } from '@/app/components/base/avatar'
import Button from '@/app/components/base/button'
import Divider from '@/app/components/base/divider'
import { useLocalFileUploader } from '@/app/components/base/image-uploader/hooks'
import { Avatar } from '@/app/components/base/ui/avatar'
import { Dialog, DialogContent } from '@/app/components/base/ui/dialog'
import { toast } from '@/app/components/base/ui/toast'
import { DISABLE_UPLOAD_IMAGE_AS_ICON } from '@/config'

View File

@ -6,9 +6,9 @@ import {
import { Fragment } from 'react'
import { useTranslation } from 'react-i18next'
import { resetUser } from '@/app/components/base/amplitude/utils'
import { Avatar } from '@/app/components/base/avatar'
import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general'
import PremiumBadge from '@/app/components/base/premium-badge'
import { Avatar } from '@/app/components/base/ui/avatar'
import { useProviderContext } from '@/context/provider-context'
import { useRouter } from '@/next/navigation'
import { useLogout, useUserProfile } from '@/service/use-common'

View File

@ -10,9 +10,9 @@ import {
import * as React from 'react'
import { useEffect, useRef } from 'react'
import { useTranslation } from 'react-i18next'
import { Avatar } from '@/app/components/base/avatar'
import Button from '@/app/components/base/button'
import Loading from '@/app/components/base/loading'
import { Avatar } from '@/app/components/base/ui/avatar'
import { toast } from '@/app/components/base/ui/toast'
import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect'

View File

@ -5,12 +5,12 @@ import { RiAddCircleFill, RiArrowRightSLine, RiOrganizationChart } from '@remixi
import { useDebounce } from 'ahooks'
import { useCallback, useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { Avatar } from '@/app/components/base/ui/avatar'
import { useSelector } from '@/context/app-context'
import { SubjectType } from '@/models/access-control'
import { useSearchForWhiteListCandidates } from '@/service/access-control'
import { cn } from '@/utils/classnames'
import useAccessControlStore from '../../../../context/access-control-store'
import { Avatar } from '../../base/avatar'
import Button from '../../base/button'
import Checkbox from '../../base/checkbox'
import Input from '../../base/input'

View File

@ -3,10 +3,10 @@ import type { AccessControlAccount, AccessControlGroup } from '@/models/access-c
import { RiAlertFill, RiCloseCircleFill, RiLockLine, RiOrganizationChart } from '@remixicon/react'
import { useCallback, useEffect } from 'react'
import { useTranslation } from 'react-i18next'
import { Avatar } from '@/app/components/base/ui/avatar'
import { AccessMode } from '@/models/access-control'
import { useAppWhiteListSubjects } from '@/service/access-control'
import useAccessControlStore from '../../../../context/access-control-store'
import { Avatar } from '../../base/avatar'
import Loading from '../../base/loading'
import Tooltip from '../../base/tooltip'
import AddMemberOrGroupDialog from './add-member-or-group-pop'

View File

@ -90,7 +90,7 @@ vi.mock('@/app/components/base/chat/chat', () => ({
},
}))
vi.mock('@/app/components/base/avatar', () => ({
vi.mock('@/app/components/base/ui/avatar', () => ({
Avatar: ({ name }: { name: string }) => <div data-testid="avatar">{name}</div>,
}))

View File

@ -7,11 +7,11 @@ import {
useCallback,
useMemo,
} from 'react'
import { Avatar } from '@/app/components/base/avatar'
import Chat from '@/app/components/base/chat/chat'
import { useChat } from '@/app/components/base/chat/chat/hooks'
import { getLastAnswer } from '@/app/components/base/chat/utils'
import { useFeatures } from '@/app/components/base/features/hooks'
import { Avatar } from '@/app/components/base/ui/avatar'
import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useAppContext } from '@/context/app-context'
import { useDebugConfigurationContext } from '@/context/debug-configuration'

View File

@ -3,11 +3,11 @@ import type { ChatConfig, ChatItem, OnSend } from '@/app/components/base/chat/ty
import type { FileEntity } from '@/app/components/base/file-uploader/types'
import { memo, useCallback, useImperativeHandle, useMemo } from 'react'
import { useStore as useAppStore } from '@/app/components/app/store'
import { Avatar } from '@/app/components/base/avatar'
import Chat from '@/app/components/base/chat/chat'
import { useChat } from '@/app/components/base/chat/chat/hooks'
import { getLastAnswer, isValidGeneratedAnswer } from '@/app/components/base/chat/utils'
import { useFeatures } from '@/app/components/base/features/hooks'
import { Avatar } from '@/app/components/base/ui/avatar'
import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useAppContext } from '@/context/app-context'
import { useDebugConfigurationContext } from '@/context/debug-configuration'

View File

@ -11,6 +11,7 @@ import AppIcon from '@/app/components/base/app-icon'
import InputsForm from '@/app/components/base/chat/chat-with-history/inputs-form'
import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested-questions'
import { Markdown } from '@/app/components/base/markdown'
import { Avatar } from '@/app/components/base/ui/avatar'
import { InputVarType } from '@/app/components/workflow/types'
import {
AppSourceType,
@ -23,7 +24,6 @@ import { submitHumanInputForm as submitHumanInputFormService } from '@/service/w
import { TransferMethod } from '@/types/app'
import { cn } from '@/utils/classnames'
import { formatBooleanInputs } from '@/utils/model-config'
import { Avatar } from '../../avatar'
import Chat from '../chat'
import { useChat } from '../chat/hooks'
import { getLastAnswer, isValidGeneratedAnswer } from '../utils'

View File

@ -12,6 +12,7 @@ import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested
import InputsForm from '@/app/components/base/chat/embedded-chatbot/inputs-form'
import LogoAvatar from '@/app/components/base/logo/logo-embedded-chat-avatar'
import { Markdown } from '@/app/components/base/markdown'
import { Avatar } from '@/app/components/base/ui/avatar'
import { InputVarType } from '@/app/components/workflow/types'
import {
AppSourceType,
@ -23,7 +24,6 @@ import {
import { submitHumanInputForm as submitHumanInputFormService } from '@/service/workflow'
import { TransferMethod } from '@/types/app'
import { cn } from '@/utils/classnames'
import { Avatar } from '../../avatar'
import Chat from '../chat'
import { useChat } from '../chat/hooks'
import { getLastAnswer, isValidGeneratedAnswer } from '../utils'

View File

@ -1,5 +1,5 @@
import { render, screen } from '@testing-library/react'
import { Avatar } from '../index'
import { Avatar } from '..'
describe('Avatar', () => {
describe('Rendering', () => {

View File

@ -36,7 +36,7 @@ function AvatarRoot({
return (
<BaseAvatar.Root
className={cn(
'relative inline-flex shrink-0 select-none items-center justify-center overflow-hidden rounded-full bg-primary-600',
'relative inline-flex shrink-0 items-center justify-center overflow-hidden rounded-full bg-primary-600 select-none',
avatarSizeClasses[size].root,
className,
)}
@ -53,7 +53,7 @@ function AvatarImage({
}: AvatarImageProps) {
return (
<BaseAvatar.Image
className={cn('absolute inset-0 size-full object-cover', className)}
className={cn('inset-0 absolute size-full object-cover', className)}
{...props}
/>
)

View File

@ -4,13 +4,13 @@ import { useDebounceFn } from 'ahooks'
import * as React from 'react'
import { useCallback, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { Avatar } from '@/app/components/base/avatar'
import Input from '@/app/components/base/input'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import { Avatar } from '@/app/components/base/ui/avatar'
import { useSelector as useAppContextWithSelector } from '@/context/app-context'
import { DatasetPermission } from '@/models/datasets'
import { cn } from '@/utils/classnames'

View File

@ -4,9 +4,9 @@ import type { MouseEventHandler, ReactNode } from 'react'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'
import { resetUser } from '@/app/components/base/amplitude/utils'
import { Avatar } from '@/app/components/base/avatar'
import PremiumBadge from '@/app/components/base/premium-badge'
import ThemeSwitcher from '@/app/components/base/theme-switcher'
import { Avatar } from '@/app/components/base/ui/avatar'
import { DropdownMenu, DropdownMenuContent, DropdownMenuGroup, DropdownMenuItem, DropdownMenuLinkItem, DropdownMenuSeparator, DropdownMenuTrigger } from '@/app/components/base/ui/dropdown-menu'
import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants'
import { IS_CLOUD_EDITION } from '@/config'

View File

@ -2,7 +2,7 @@
import type { InvitationResult } from '@/models/common'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'
import { Avatar } from '@/app/components/base/avatar'
import { Avatar } from '@/app/components/base/ui/avatar'
import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip'
import { NUM_INFINITE } from '@/app/components/billing/config'
import { Plan } from '@/app/components/billing/type'

View File

@ -3,9 +3,9 @@ import type { FC } from 'react'
import * as React from 'react'
import { useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { Avatar } from '@/app/components/base/avatar'
import Input from '@/app/components/base/input'
import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem'
import { Avatar } from '@/app/components/base/ui/avatar'
import { useMembers } from '@/service/use-common'
import { cn } from '@/utils/classnames'

View File

@ -3,7 +3,7 @@ import type { Member } from '@/models/common'
import { RiCloseCircleFill, RiErrorWarningFill } from '@remixicon/react'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import { Avatar } from '@/app/components/base/avatar'
import { Avatar } from '@/app/components/base/ui/avatar'
import { cn } from '@/utils/classnames'
type Props = {

View File

@ -4,8 +4,8 @@ import type { Recipient } from '@/app/components/workflow/nodes/human-input/type
import type { Member } from '@/models/common'
import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import { Avatar } from '@/app/components/base/avatar'
import Input from '@/app/components/base/input'
import { Avatar } from '@/app/components/base/ui/avatar'
import { cn } from '@/utils/classnames'
const i18nPrefix = 'nodes.humanInput'

View File

@ -1,7 +1,7 @@
import { useTranslation } from 'react-i18next'
import { Avatar } from '@/app/components/base/avatar'
import Button from '@/app/components/base/button'
import { Triangle } from '@/app/components/base/icons/src/public/education'
import { Avatar } from '@/app/components/base/ui/avatar'
import { useAppContext } from '@/context/app-context'
import { useRouter } from '@/next/navigation'
import { useLogout } from '@/service/use-common'

View File

@ -1565,11 +1565,6 @@
"count": 1
}
},
"app/components/base/avatar/index.tsx": {
"tailwindcss/enforce-consistent-class-order": {
"count": 2
}
},
"app/components/base/badge.tsx": {
"tailwindcss/enforce-consistent-class-order": {
"count": 2

View File

@ -171,7 +171,7 @@ export default antfu(
},
{
name: 'dify/base-ui-primitives',
files: ['app/components/base/ui/**/*.tsx', 'app/components/base/avatar/**/*.tsx'],
files: ['app/components/base/ui/**/*.tsx'],
rules: {
'react-refresh/only-export-components': 'off',
},