mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 18:27:19 +08:00
Merge branch 'main' into jzh
This commit is contained in:
commit
4680535ecd
@ -227,10 +227,11 @@ class ExternalApiUseCheckApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, external_knowledge_api_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
|
||||
external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
|
||||
external_knowledge_api_id
|
||||
external_knowledge_api_id, current_tenant_id
|
||||
)
|
||||
return {"is_using": external_knowledge_api_is_using, "count": count}, 200
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from graphon.file import File, FileUploadConfig
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity
|
||||
@ -131,7 +131,7 @@ class AppGenerateEntity(BaseModel):
|
||||
extras: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# tracing instance
|
||||
trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False)
|
||||
trace_manager: "TraceQueueManager | None" = Field(default=None, exclude=True, repr=False)
|
||||
|
||||
|
||||
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, Literal, Optional, TypedDict
|
||||
from typing import Any, Literal, TypedDict
|
||||
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
@ -17,7 +17,7 @@ class DatasourceApiEntity(BaseModel):
|
||||
output_schema: dict | None = None
|
||||
|
||||
|
||||
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
|
||||
ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow"] | None
|
||||
|
||||
|
||||
class DatasourceProviderApiEntityDict(TypedDict):
|
||||
|
||||
@ -71,8 +71,8 @@ class DatasourceFileMessageTransformer:
|
||||
if not isinstance(message.message, DatasourceMessage.BlobMessage):
|
||||
raise ValueError("unexpected message type")
|
||||
|
||||
# FIXME: should do a type check here.
|
||||
assert isinstance(message.message.blob, bytes)
|
||||
if not isinstance(message.message.blob, bytes):
|
||||
raise TypeError(f"Expected blob to be bytes, got {type(message.message.blob).__name__}")
|
||||
tool_file_manager = ToolFileManager()
|
||||
blob_tool_file: ToolFile | None = tool_file_manager.create_file_by_raw(
|
||||
user_id=user_id,
|
||||
|
||||
@ -122,7 +122,7 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
logger.exception("Authentication retry failed")
|
||||
raise MCPAuthError(f"Authentication retry failed: {e}") from e
|
||||
|
||||
def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any:
|
||||
def _execute_with_retry[**P, R](self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""
|
||||
Execute a function with authentication retry logic.
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Any, TypeVar
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -9,12 +9,9 @@ from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAut
|
||||
|
||||
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
|
||||
|
||||
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
|
||||
LifespanContextT = TypeVar("LifespanContextT")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext[SessionT: BaseSession[Any, Any, Any, Any, Any], LifespanContextT]:
|
||||
class RequestContext[SessionT: BaseSession, LifespanContextT]:
|
||||
request_id: RequestId
|
||||
meta: RequestParams.Meta | None
|
||||
session: SessionT
|
||||
|
||||
@ -6,7 +6,7 @@ from graphon.model_runtime.callbacks.base_callback import Callback
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType
|
||||
from graphon.model_runtime.entities.rerank_entities import RerankResult
|
||||
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
@ -172,10 +172,10 @@ class ModelInstance:
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
tools=list(tools) if tools else None,
|
||||
stop=list(stop) if stop else None,
|
||||
stream=stream,
|
||||
callbacks=callbacks,
|
||||
),
|
||||
@ -193,15 +193,12 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, LargeLanguageModel):
|
||||
raise Exception("Model type instance is not LargeLanguageModel")
|
||||
return cast(
|
||||
int,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=list(prompt_messages),
|
||||
tools=list(tools) if tools else None,
|
||||
)
|
||||
|
||||
def invoke_text_embedding(
|
||||
@ -216,15 +213,12 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return cast(
|
||||
EmbeddingResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
input_type=input_type,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
input_type=input_type,
|
||||
)
|
||||
|
||||
def invoke_multimodal_embedding(
|
||||
@ -241,15 +235,12 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return cast(
|
||||
EmbeddingResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
multimodel_documents=multimodel_documents,
|
||||
input_type=input_type,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
multimodel_documents=multimodel_documents,
|
||||
input_type=input_type,
|
||||
)
|
||||
|
||||
def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
|
||||
@ -261,14 +252,11 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return cast(
|
||||
list[int],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
)
|
||||
|
||||
def invoke_rerank(
|
||||
@ -289,23 +277,20 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, RerankModel):
|
||||
raise Exception("Model type instance is not RerankModel")
|
||||
return cast(
|
||||
RerankResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
)
|
||||
|
||||
def invoke_multimodal_rerank(
|
||||
self,
|
||||
query: dict,
|
||||
docs: list[dict],
|
||||
query: MultimodalRerankInput,
|
||||
docs: list[MultimodalRerankInput],
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
) -> RerankResult:
|
||||
@ -320,17 +305,14 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, RerankModel):
|
||||
raise Exception("Model type instance is not RerankModel")
|
||||
return cast(
|
||||
RerankResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke_multimodal_rerank,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke_multimodal_rerank,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
)
|
||||
|
||||
def invoke_moderation(self, text: str) -> bool:
|
||||
@ -342,14 +324,11 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, ModerationModel):
|
||||
raise Exception("Model type instance is not ModerationModel")
|
||||
return cast(
|
||||
bool,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
text=text,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
text=text,
|
||||
)
|
||||
|
||||
def invoke_speech2text(self, file: IO[bytes]) -> str:
|
||||
@ -361,14 +340,11 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, Speech2TextModel):
|
||||
raise Exception("Model type instance is not Speech2TextModel")
|
||||
return cast(
|
||||
str,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
file=file,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
file=file,
|
||||
)
|
||||
|
||||
def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]:
|
||||
@ -381,18 +357,15 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TTSModel):
|
||||
raise Exception("Model type instance is not TTSModel")
|
||||
return cast(
|
||||
Iterable[bytes],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
)
|
||||
|
||||
def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
|
||||
def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""
|
||||
Round-robin invoke
|
||||
:param function: function to invoke
|
||||
@ -430,9 +403,8 @@ class ModelInstance:
|
||||
continue
|
||||
|
||||
try:
|
||||
if "credentials" in kwargs:
|
||||
del kwargs["credentials"]
|
||||
return function(*args, **kwargs, credentials=lb_config.credentials)
|
||||
kwargs["credentials"] = lb_config.credentials
|
||||
return function(*args, **kwargs)
|
||||
except InvokeRateLimitError as e:
|
||||
# expire in 60 seconds
|
||||
self.load_balancing_manager.cooldown(lb_config, expire=60)
|
||||
|
||||
@ -7,7 +7,7 @@ from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import Column, String, Table, create_engine, insert
|
||||
from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.dialects.postgresql import JSON, TEXT
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
@ -79,7 +79,7 @@ class RelytVector(BaseVector):
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
index_name = f"{self._collection_name}_embedding_index"
|
||||
with Session(self.client) as session:
|
||||
with sessionmaker(bind=self.client).begin() as session:
|
||||
drop_statement = sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}"; """)
|
||||
session.execute(drop_statement)
|
||||
create_statement = sql_text(f"""
|
||||
@ -104,7 +104,6 @@ class RelytVector(BaseVector):
|
||||
$$);
|
||||
""")
|
||||
session.execute(index_statement)
|
||||
session.commit()
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
@ -208,9 +207,8 @@ class RelytVector(BaseVector):
|
||||
self.delete_by_uuids(ids)
|
||||
|
||||
def delete(self):
|
||||
with Session(self.client) as session:
|
||||
with sessionmaker(bind=self.client).begin() as session:
|
||||
session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";"""))
|
||||
session.commit()
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
with Session(self.client) as session:
|
||||
|
||||
@ -6,7 +6,7 @@ import sqlalchemy
|
||||
from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
|
||||
from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.orm import Session, declarative_base
|
||||
from sqlalchemy.orm import Session, declarative_base, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field, parse_metadata_json
|
||||
@ -97,8 +97,7 @@ class TiDBVector(BaseVector):
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
tidb_dist_func = self._get_distance_func()
|
||||
with Session(self._engine) as session:
|
||||
session.begin()
|
||||
with sessionmaker(bind=self._engine).begin() as session:
|
||||
create_statement = sql_text(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self._collection_name} (
|
||||
id CHAR(36) PRIMARY KEY,
|
||||
@ -115,7 +114,6 @@ class TiDBVector(BaseVector):
|
||||
);
|
||||
""")
|
||||
session.execute(create_statement)
|
||||
session.commit()
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
@ -238,9 +236,8 @@ class TiDBVector(BaseVector):
|
||||
return []
|
||||
|
||||
def delete(self):
|
||||
with Session(self._engine) as session:
|
||||
with sessionmaker(bind=self._engine).begin() as session:
|
||||
session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
|
||||
session.commit()
|
||||
|
||||
def _get_distance_func(self) -> str:
|
||||
match self._distance_func:
|
||||
|
||||
@ -3,8 +3,7 @@
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -55,6 +54,12 @@ from services.summary_index_service import SummaryIndexService
|
||||
_file_access_controller = DatabaseFileAccessController()
|
||||
|
||||
|
||||
class ParagraphFormatPreviewDict(TypedDict):
|
||||
chunk_structure: str
|
||||
preview: list[dict[str, Any]]
|
||||
total_segments: int
|
||||
|
||||
|
||||
class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||
text_docs = ExtractProcessor.extract(
|
||||
@ -266,16 +271,17 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
keyword = Keyword(dataset)
|
||||
keyword.add_texts(documents)
|
||||
|
||||
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
|
||||
def format_preview(self, chunks: Any) -> ParagraphFormatPreviewDict:
|
||||
if isinstance(chunks, list):
|
||||
preview = []
|
||||
for content in chunks:
|
||||
preview.append({"content": content})
|
||||
return {
|
||||
result: ParagraphFormatPreviewDict = {
|
||||
"chunk_structure": IndexStructureType.PARAGRAPH_INDEX,
|
||||
"preview": preview,
|
||||
"total_segments": len(chunks),
|
||||
}
|
||||
return result
|
||||
else:
|
||||
raise ValueError("Chunks is not a list")
|
||||
|
||||
|
||||
@ -3,8 +3,7 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
@ -36,6 +35,13 @@ from services.summary_index_service import SummaryIndexService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ParentChildFormatPreviewDict(TypedDict):
|
||||
chunk_structure: str
|
||||
parent_mode: str
|
||||
preview: list[dict[str, Any]]
|
||||
total_segments: int
|
||||
|
||||
|
||||
class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||
text_docs = ExtractProcessor.extract(
|
||||
@ -351,17 +357,18 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
if all_multimodal_documents and dataset.is_multimodal:
|
||||
vector.create_multimodal(all_multimodal_documents)
|
||||
|
||||
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
|
||||
def format_preview(self, chunks: Any) -> ParentChildFormatPreviewDict:
|
||||
parent_childs = ParentChildStructureChunk.model_validate(chunks)
|
||||
preview = []
|
||||
for parent_child in parent_childs.parent_child_chunks:
|
||||
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
|
||||
return {
|
||||
result: ParentChildFormatPreviewDict = {
|
||||
"chunk_structure": IndexStructureType.PARENT_CHILD_INDEX,
|
||||
"parent_mode": parent_childs.parent_mode,
|
||||
"preview": preview,
|
||||
"total_segments": len(parent_childs.parent_child_chunks),
|
||||
}
|
||||
return result
|
||||
|
||||
def generate_summary_preview(
|
||||
self,
|
||||
|
||||
@ -4,8 +4,7 @@ import logging
|
||||
import re
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import pandas as pd
|
||||
from flask import Flask, current_app
|
||||
@ -36,6 +35,12 @@ from services.summary_index_service import SummaryIndexService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QAFormatPreviewDict(TypedDict):
|
||||
chunk_structure: str
|
||||
qa_preview: list[dict[str, Any]]
|
||||
total_segments: int
|
||||
|
||||
|
||||
class QAIndexProcessor(BaseIndexProcessor):
|
||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||
text_docs = ExtractProcessor.extract(
|
||||
@ -230,16 +235,17 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
else:
|
||||
raise ValueError("Indexing technique must be high quality.")
|
||||
|
||||
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
|
||||
def format_preview(self, chunks: Any) -> QAFormatPreviewDict:
|
||||
qa_chunks = QAStructureChunk.model_validate(chunks)
|
||||
preview = []
|
||||
for qa_chunk in qa_chunks.qa_chunks:
|
||||
preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})
|
||||
return {
|
||||
result: QAFormatPreviewDict = {
|
||||
"chunk_structure": IndexStructureType.QA_INDEX,
|
||||
"qa_preview": preview,
|
||||
"total_segments": len(qa_chunks.qa_chunks),
|
||||
}
|
||||
return result
|
||||
|
||||
def generate_summary_preview(
|
||||
self,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import base64
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.entities.rerank_entities import RerankResult
|
||||
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
@ -123,7 +123,7 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
:param query_type: query type
|
||||
:return: rerank result
|
||||
"""
|
||||
docs = []
|
||||
docs: list[MultimodalRerankInput] = []
|
||||
doc_ids = set()
|
||||
unique_documents = []
|
||||
for document in documents:
|
||||
@ -138,26 +138,28 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
if upload_file:
|
||||
blob = storage.load_once(upload_file.key)
|
||||
document_file_base64 = base64.b64encode(blob).decode()
|
||||
document_file_dict = {
|
||||
"content": document_file_base64,
|
||||
"content_type": document.metadata["doc_type"],
|
||||
}
|
||||
docs.append(document_file_dict)
|
||||
docs.append(
|
||||
MultimodalRerankInput(
|
||||
content=document_file_base64,
|
||||
content_type=document.metadata["doc_type"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
document_text_dict = {
|
||||
"content": document.page_content,
|
||||
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
|
||||
}
|
||||
docs.append(document_text_dict)
|
||||
docs.append(
|
||||
MultimodalRerankInput(
|
||||
content=document.page_content,
|
||||
content_type=document.metadata.get("doc_type") or DocType.TEXT,
|
||||
)
|
||||
)
|
||||
doc_ids.add(document.metadata["doc_id"])
|
||||
unique_documents.append(document)
|
||||
elif document.provider == "external":
|
||||
if document not in unique_documents:
|
||||
docs.append(
|
||||
{
|
||||
"content": document.page_content,
|
||||
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
|
||||
}
|
||||
MultimodalRerankInput(
|
||||
content=document.page_content,
|
||||
content_type=document.metadata.get("doc_type") or DocType.TEXT,
|
||||
)
|
||||
)
|
||||
unique_documents.append(document)
|
||||
|
||||
@ -171,12 +173,12 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
if upload_file:
|
||||
blob = storage.load_once(upload_file.key)
|
||||
file_query = base64.b64encode(blob).decode()
|
||||
file_query_dict = {
|
||||
"content": file_query,
|
||||
"content_type": DocType.IMAGE,
|
||||
}
|
||||
file_query_input = MultimodalRerankInput(
|
||||
content=file_query,
|
||||
content_type=DocType.IMAGE,
|
||||
)
|
||||
rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
|
||||
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n
|
||||
query=file_query_input, docs=docs, score_threshold=score_threshold, top_n=top_n
|
||||
)
|
||||
return rerank_result, unique_documents
|
||||
else:
|
||||
|
||||
@ -118,7 +118,8 @@ class ToolFileMessageTransformer:
|
||||
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
|
||||
raise ValueError("unexpected message type")
|
||||
|
||||
assert isinstance(message.message.blob, bytes)
|
||||
if not isinstance(message.message.blob, bytes):
|
||||
raise TypeError(f"Expected blob to be bytes, got {type(message.message.blob).__name__}")
|
||||
tool_file_manager = ToolFileManager()
|
||||
tool_file = tool_file_manager.create_file_by_raw(
|
||||
user_id=user_id,
|
||||
|
||||
@ -10,7 +10,7 @@ import uuid
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from datetime import datetime
|
||||
from hashlib import sha256
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Optional, Protocol, Union, cast
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Protocol, cast
|
||||
from uuid import UUID
|
||||
from zoneinfo import available_timezones
|
||||
|
||||
@ -81,7 +81,7 @@ def escape_like_pattern(pattern: str) -> str:
|
||||
return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
|
||||
|
||||
def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None:
|
||||
def extract_tenant_id(user: "Account | EndUser") -> str | None:
|
||||
"""
|
||||
Extract tenant_id from Account or EndUser object.
|
||||
|
||||
@ -164,7 +164,10 @@ def email(email):
|
||||
EmailStr = Annotated[str, AfterValidator(email)]
|
||||
|
||||
|
||||
def uuid_value(value: Any) -> str:
|
||||
def uuid_value(value: str | UUID) -> str:
|
||||
if isinstance(value, UUID):
|
||||
return str(value)
|
||||
|
||||
if value == "":
|
||||
return str(value)
|
||||
|
||||
@ -405,7 +408,7 @@ class TokenManager:
|
||||
def generate_token(
|
||||
cls,
|
||||
token_type: str,
|
||||
account: Optional["Account"] = None,
|
||||
account: "Account | None" = None,
|
||||
email: str | None = None,
|
||||
additional_data: dict | None = None,
|
||||
) -> str:
|
||||
@ -465,9 +468,7 @@ class TokenManager:
|
||||
return current_token
|
||||
|
||||
@classmethod
|
||||
def _set_current_token_for_account(
|
||||
cls, account_id: str, token: str, token_type: str, expiry_minutes: Union[int, float]
|
||||
):
|
||||
def _set_current_token_for_account(cls, account_id: str, token: str, token_type: str, expiry_minutes: int | float):
|
||||
key = cls._get_account_token_key(account_id, token_type)
|
||||
expiry_seconds = int(expiry_minutes * 60)
|
||||
redis_client.setex(key, expiry_seconds, token)
|
||||
|
||||
@ -913,11 +913,7 @@ class TrialApp(TypeBase):
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
insert_default=func.current_timestamp(),
|
||||
server_default=func.current_timestamp(),
|
||||
init=False,
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
trial_limit: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=3)
|
||||
|
||||
@ -941,11 +937,7 @@ class AccountTrialAppRecord(TypeBase):
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
insert_default=func.current_timestamp(),
|
||||
server_default=func.current_timestamp(),
|
||||
init=False,
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@ -148,18 +148,23 @@ class ExternalDatasetService:
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
|
||||
def external_knowledge_api_use_check(external_knowledge_api_id: str, tenant_id: str) -> tuple[bool, int]:
|
||||
"""
|
||||
Return usage for an external knowledge API within a single tenant.
|
||||
|
||||
The caller already scopes access by tenant, so this query must do the
|
||||
same; otherwise the endpoint becomes a cross-tenant UUID oracle.
|
||||
"""
|
||||
count = (
|
||||
db.session.scalar(
|
||||
select(func.count(ExternalKnowledgeBindings.id)).where(
|
||||
ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id
|
||||
ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id,
|
||||
ExternalKnowledgeBindings.tenant_id == tenant_id,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
if count > 0:
|
||||
return True, count
|
||||
return False, 0
|
||||
return count > 0, count
|
||||
|
||||
@staticmethod
|
||||
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from importlib import import_module
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
@ -11,6 +12,7 @@ from controllers.console.datasets.external import (
|
||||
BedrockRetrievalApi,
|
||||
ExternalApiTemplateApi,
|
||||
ExternalApiTemplateListApi,
|
||||
ExternalApiUseCheckApi,
|
||||
ExternalDatasetCreateApi,
|
||||
ExternalKnowledgeHitTestingApi,
|
||||
)
|
||||
@ -19,6 +21,8 @@ from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
from services.knowledge_service import ExternalDatasetTestService
|
||||
|
||||
external_controller = import_module("controllers.console.datasets.external")
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
@ -44,10 +48,11 @@ def current_user():
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth(mocker, current_user):
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.external.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
def mock_auth(monkeypatch, current_user):
|
||||
monkeypatch.setattr(
|
||||
external_controller,
|
||||
"current_account_with_tenant",
|
||||
lambda: (current_user, "tenant-1"),
|
||||
)
|
||||
|
||||
|
||||
@ -136,6 +141,26 @@ class TestExternalApiTemplateApi:
|
||||
method(api, "api-id")
|
||||
|
||||
|
||||
class TestExternalApiUseCheckApi:
|
||||
def test_get_scopes_usage_check_to_current_tenant(self, app):
|
||||
api = ExternalApiUseCheckApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
ExternalDatasetService,
|
||||
"external_knowledge_api_use_check",
|
||||
return_value=(True, 2),
|
||||
) as mock_use_check,
|
||||
):
|
||||
response, status = method(api, "api-id")
|
||||
|
||||
assert status == 200
|
||||
assert response == {"is_using": True, "count": 2}
|
||||
mock_use_check.assert_called_once_with("api-id", "tenant-1")
|
||||
|
||||
|
||||
class TestExternalDatasetCreateApi:
|
||||
def test_create_success(self, app):
|
||||
api = ExternalDatasetCreateApi()
|
||||
|
||||
@ -4,9 +4,7 @@ from unittest.mock import Mock
|
||||
|
||||
from core.mcp.entities import (
|
||||
SUPPORTED_PROTOCOL_VERSIONS,
|
||||
LifespanContextT,
|
||||
RequestContext,
|
||||
SessionT,
|
||||
)
|
||||
from core.mcp.session.base_session import BaseSession
|
||||
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestParams
|
||||
@ -198,42 +196,3 @@ class TestRequestContext:
|
||||
assert "RequestContext" in repr_str
|
||||
assert "test-123" in repr_str
|
||||
assert "MockSession" in repr_str
|
||||
|
||||
|
||||
class TestTypeVariables:
|
||||
"""Test type variables defined in the module."""
|
||||
|
||||
def test_session_type_var(self):
|
||||
"""Test SessionT type variable."""
|
||||
|
||||
# Create a custom session class
|
||||
class CustomSession(BaseSession):
|
||||
pass
|
||||
|
||||
# Use in generic context
|
||||
def process_session(session: SessionT) -> SessionT:
|
||||
return session
|
||||
|
||||
mock_session = Mock(spec=CustomSession)
|
||||
result = process_session(mock_session)
|
||||
assert result == mock_session
|
||||
|
||||
def test_lifespan_context_type_var(self):
|
||||
"""Test LifespanContextT type variable."""
|
||||
|
||||
# Use in generic context
|
||||
def process_lifespan(context: LifespanContextT) -> LifespanContextT:
|
||||
return context
|
||||
|
||||
# Test with different types
|
||||
str_context = "string-context"
|
||||
assert process_lifespan(str_context) == str_context
|
||||
|
||||
dict_context = {"key": "value"}
|
||||
assert process_lifespan(dict_context) == dict_context
|
||||
|
||||
class CustomContext:
|
||||
pass
|
||||
|
||||
custom_context = CustomContext()
|
||||
assert process_lifespan(custom_context) == custom_context
|
||||
|
||||
@ -39,6 +39,25 @@ class _FakeSession:
|
||||
return None
|
||||
|
||||
|
||||
class _FakeBeginContext:
|
||||
def __init__(self, session):
|
||||
self._session = session
|
||||
|
||||
def __enter__(self):
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
|
||||
def _patch_both(monkeypatch, module, session):
|
||||
"""Patch both Session and sessionmaker on the module."""
|
||||
monkeypatch.setattr(module, "Session", lambda _client: session)
|
||||
monkeypatch.setattr(
|
||||
module, "sessionmaker", lambda **kwargs: MagicMock(begin=MagicMock(return_value=_FakeBeginContext(session)))
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def relyt_module(monkeypatch):
|
||||
for name, module in _build_fake_relyt_modules().items():
|
||||
@ -108,13 +127,13 @@ def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch):
|
||||
|
||||
monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=1))
|
||||
session = _FakeSession()
|
||||
monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
|
||||
_patch_both(monkeypatch, relyt_module, session)
|
||||
vector.create_collection(3)
|
||||
session.execute.assert_not_called()
|
||||
|
||||
monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=None))
|
||||
session = _FakeSession()
|
||||
monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
|
||||
_patch_both(monkeypatch, relyt_module, session)
|
||||
vector.create_collection(3)
|
||||
executed_sql = [str(call.args[0]) for call in session.execute.call_args_list]
|
||||
assert any("DROP TABLE IF EXISTS" in sql for sql in executed_sql)
|
||||
@ -265,15 +284,15 @@ def test_search_by_vector_filters_by_score_and_ids(relyt_module):
|
||||
|
||||
|
||||
# 8. delete commits session
|
||||
def test_delete_commits_session(relyt_module, monkeypatch):
|
||||
def test_delete_drops_table(relyt_module, monkeypatch):
|
||||
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector.client = MagicMock()
|
||||
vector.embedding_dimension = 3
|
||||
session = _FakeSession()
|
||||
monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
|
||||
_patch_both(monkeypatch, relyt_module, session)
|
||||
vector.delete()
|
||||
session.commit.assert_called_once()
|
||||
session.execute.assert_called_once()
|
||||
|
||||
|
||||
def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch):
|
||||
|
||||
@ -137,14 +137,15 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke
|
||||
|
||||
session = MagicMock()
|
||||
|
||||
class _SessionCtx:
|
||||
class _BeginCtx:
|
||||
def __enter__(self):
|
||||
return session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
|
||||
mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx()))
|
||||
monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm)
|
||||
|
||||
vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
|
||||
vector._collection_name = "collection_1"
|
||||
@ -153,11 +154,9 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke
|
||||
|
||||
vector._create_collection(3)
|
||||
|
||||
session.begin.assert_called_once()
|
||||
sql = str(session.execute.call_args.args[0])
|
||||
assert "VECTOR<FLOAT>(3)" in sql
|
||||
assert "VEC_L2_DISTANCE" in sql
|
||||
session.commit.assert_called_once()
|
||||
tidb_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
@ -396,23 +395,22 @@ def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch):
|
||||
def test_delete_drops_table(tidb_module, monkeypatch):
|
||||
session = MagicMock()
|
||||
session.execute.return_value = None
|
||||
session.commit = MagicMock()
|
||||
|
||||
class _SessionCtx:
|
||||
class _BeginCtx:
|
||||
def __enter__(self):
|
||||
return session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
|
||||
mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx()))
|
||||
monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm)
|
||||
vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector._engine = MagicMock()
|
||||
vector.delete()
|
||||
drop_sql = str(session.execute.call_args.args[0])
|
||||
assert "DROP TABLE IF EXISTS collection_1" in drop_sql
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch):
|
||||
|
||||
@ -396,10 +396,11 @@ class TestExternalDatasetServiceUsageAndBindings:
|
||||
|
||||
mock_db_session.scalar.return_value = 3
|
||||
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1")
|
||||
|
||||
assert in_use is True
|
||||
assert count == 3
|
||||
assert "tenant_id" in str(mock_db_session.scalar.call_args.args[0])
|
||||
|
||||
def test_external_knowledge_api_use_check_not_in_use(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
@ -408,7 +409,7 @@ class TestExternalDatasetServiceUsageAndBindings:
|
||||
|
||||
mock_db_session.scalar.return_value = 0
|
||||
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1")
|
||||
|
||||
assert in_use is False
|
||||
assert count == 0
|
||||
|
||||
@ -974,26 +974,29 @@ class TestExternalDatasetServiceAPIUseCheck:
|
||||
"""Test API use check when API has one binding."""
|
||||
# Arrange
|
||||
api_id = "api-123"
|
||||
tenant_id = "tenant-123"
|
||||
|
||||
mock_db.session.scalar.return_value = 1
|
||||
|
||||
# Act
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id)
|
||||
|
||||
# Assert
|
||||
assert in_use is True
|
||||
assert count == 1
|
||||
assert "tenant_id" in str(mock_db.session.scalar.call_args.args[0])
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_external_knowledge_api_use_check_in_use_multiple(self, mock_db, factory):
|
||||
"""Test API use check with multiple bindings."""
|
||||
# Arrange
|
||||
api_id = "api-123"
|
||||
tenant_id = "tenant-123"
|
||||
|
||||
mock_db.session.scalar.return_value = 10
|
||||
|
||||
# Act
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id)
|
||||
|
||||
# Assert
|
||||
assert in_use is True
|
||||
@ -1004,11 +1007,12 @@ class TestExternalDatasetServiceAPIUseCheck:
|
||||
"""Test API use check when API is not in use."""
|
||||
# Arrange
|
||||
api_id = "api-123"
|
||||
tenant_id = "tenant-123"
|
||||
|
||||
mock_db.session.scalar.return_value = 0
|
||||
|
||||
# Act
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id)
|
||||
|
||||
# Assert
|
||||
assert in_use is False
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
import type { Area } from 'react-easy-crop'
|
||||
import type { OnImageInput } from '@/app/components/base/app-icon-picker/ImageInput'
|
||||
import type { AvatarProps } from '@/app/components/base/avatar'
|
||||
import type { AvatarProps } from '@/app/components/base/ui/avatar'
|
||||
import type { ImageFile } from '@/types/app'
|
||||
import { RiDeleteBin5Line, RiPencilLine } from '@remixicon/react'
|
||||
import * as React from 'react'
|
||||
@ -10,10 +10,10 @@ import { useCallback, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import ImageInput from '@/app/components/base/app-icon-picker/ImageInput'
|
||||
import getCroppedImg from '@/app/components/base/app-icon-picker/utils'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import { useLocalFileUploader } from '@/app/components/base/image-uploader/hooks'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { Dialog, DialogContent } from '@/app/components/base/ui/dialog'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import { DISABLE_UPLOAD_IMAGE_AS_ICON } from '@/config'
|
||||
|
||||
@ -6,9 +6,9 @@ import {
|
||||
import { Fragment } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { resetUser } from '@/app/components/base/amplitude/utils'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general'
|
||||
import PremiumBadge from '@/app/components/base/premium-badge'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { useRouter } from '@/next/navigation'
|
||||
import { useLogout, useUserProfile } from '@/service/use-common'
|
||||
|
||||
@ -10,9 +10,9 @@ import {
|
||||
import * as React from 'react'
|
||||
import { useEffect, useRef } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect'
|
||||
|
||||
@ -5,12 +5,12 @@ import { RiAddCircleFill, RiArrowRightSLine, RiOrganizationChart } from '@remixi
|
||||
import { useDebounce } from 'ahooks'
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { useSelector } from '@/context/app-context'
|
||||
import { SubjectType } from '@/models/access-control'
|
||||
import { useSearchForWhiteListCandidates } from '@/service/access-control'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import useAccessControlStore from '../../../../context/access-control-store'
|
||||
import { Avatar } from '../../base/avatar'
|
||||
import Button from '../../base/button'
|
||||
import Checkbox from '../../base/checkbox'
|
||||
import Input from '../../base/input'
|
||||
|
||||
@ -3,10 +3,10 @@ import type { AccessControlAccount, AccessControlGroup } from '@/models/access-c
|
||||
import { RiAlertFill, RiCloseCircleFill, RiLockLine, RiOrganizationChart } from '@remixicon/react'
|
||||
import { useCallback, useEffect } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
import { useAppWhiteListSubjects } from '@/service/access-control'
|
||||
import useAccessControlStore from '../../../../context/access-control-store'
|
||||
import { Avatar } from '../../base/avatar'
|
||||
import Loading from '../../base/loading'
|
||||
import Tooltip from '../../base/tooltip'
|
||||
import AddMemberOrGroupDialog from './add-member-or-group-pop'
|
||||
|
||||
@ -90,7 +90,7 @@ vi.mock('@/app/components/base/chat/chat', () => ({
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/avatar', () => ({
|
||||
vi.mock('@/app/components/base/ui/avatar', () => ({
|
||||
Avatar: ({ name }: { name: string }) => <div data-testid="avatar">{name}</div>,
|
||||
}))
|
||||
|
||||
|
||||
@ -7,11 +7,11 @@ import {
|
||||
useCallback,
|
||||
useMemo,
|
||||
} from 'react'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Chat from '@/app/components/base/chat/chat'
|
||||
import { useChat } from '@/app/components/base/chat/chat/hooks'
|
||||
import { getLastAnswer } from '@/app/components/base/chat/utils'
|
||||
import { useFeatures } from '@/app/components/base/features/hooks'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useDebugConfigurationContext } from '@/context/debug-configuration'
|
||||
|
||||
@ -3,11 +3,11 @@ import type { ChatConfig, ChatItem, OnSend } from '@/app/components/base/chat/ty
|
||||
import type { FileEntity } from '@/app/components/base/file-uploader/types'
|
||||
import { memo, useCallback, useImperativeHandle, useMemo } from 'react'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Chat from '@/app/components/base/chat/chat'
|
||||
import { useChat } from '@/app/components/base/chat/chat/hooks'
|
||||
import { getLastAnswer, isValidGeneratedAnswer } from '@/app/components/base/chat/utils'
|
||||
import { useFeatures } from '@/app/components/base/features/hooks'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useDebugConfigurationContext } from '@/context/debug-configuration'
|
||||
|
||||
@ -11,6 +11,7 @@ import AppIcon from '@/app/components/base/app-icon'
|
||||
import InputsForm from '@/app/components/base/chat/chat-with-history/inputs-form'
|
||||
import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested-questions'
|
||||
import { Markdown } from '@/app/components/base/markdown'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { InputVarType } from '@/app/components/workflow/types'
|
||||
import {
|
||||
AppSourceType,
|
||||
@ -23,7 +24,6 @@ import { submitHumanInputForm as submitHumanInputFormService } from '@/service/w
|
||||
import { TransferMethod } from '@/types/app'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { formatBooleanInputs } from '@/utils/model-config'
|
||||
import { Avatar } from '../../avatar'
|
||||
import Chat from '../chat'
|
||||
import { useChat } from '../chat/hooks'
|
||||
import { getLastAnswer, isValidGeneratedAnswer } from '../utils'
|
||||
|
||||
@ -12,6 +12,7 @@ import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested
|
||||
import InputsForm from '@/app/components/base/chat/embedded-chatbot/inputs-form'
|
||||
import LogoAvatar from '@/app/components/base/logo/logo-embedded-chat-avatar'
|
||||
import { Markdown } from '@/app/components/base/markdown'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { InputVarType } from '@/app/components/workflow/types'
|
||||
import {
|
||||
AppSourceType,
|
||||
@ -23,7 +24,6 @@ import {
|
||||
import { submitHumanInputForm as submitHumanInputFormService } from '@/service/workflow'
|
||||
import { TransferMethod } from '@/types/app'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { Avatar } from '../../avatar'
|
||||
import Chat from '../chat'
|
||||
import { useChat } from '../chat/hooks'
|
||||
import { getLastAnswer, isValidGeneratedAnswer } from '../utils'
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { Avatar } from '../index'
|
||||
import { Avatar } from '..'
|
||||
|
||||
describe('Avatar', () => {
|
||||
describe('Rendering', () => {
|
||||
@ -36,7 +36,7 @@ function AvatarRoot({
|
||||
return (
|
||||
<BaseAvatar.Root
|
||||
className={cn(
|
||||
'relative inline-flex shrink-0 select-none items-center justify-center overflow-hidden rounded-full bg-primary-600',
|
||||
'relative inline-flex shrink-0 items-center justify-center overflow-hidden rounded-full bg-primary-600 select-none',
|
||||
avatarSizeClasses[size].root,
|
||||
className,
|
||||
)}
|
||||
@ -53,7 +53,7 @@ function AvatarImage({
|
||||
}: AvatarImageProps) {
|
||||
return (
|
||||
<BaseAvatar.Image
|
||||
className={cn('absolute inset-0 size-full object-cover', className)}
|
||||
className={cn('inset-0 absolute size-full object-cover', className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
@ -4,13 +4,13 @@ import { useDebounceFn } from 'ahooks'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Input from '@/app/components/base/input'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
PortalToFollowElemTrigger,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { useSelector as useAppContextWithSelector } from '@/context/app-context'
|
||||
import { DatasetPermission } from '@/models/datasets'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
@ -4,9 +4,9 @@ import type { MouseEventHandler, ReactNode } from 'react'
|
||||
import { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { resetUser } from '@/app/components/base/amplitude/utils'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import PremiumBadge from '@/app/components/base/premium-badge'
|
||||
import ThemeSwitcher from '@/app/components/base/theme-switcher'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { DropdownMenu, DropdownMenuContent, DropdownMenuGroup, DropdownMenuItem, DropdownMenuLinkItem, DropdownMenuSeparator, DropdownMenuTrigger } from '@/app/components/base/ui/dropdown-menu'
|
||||
import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants'
|
||||
import { IS_CLOUD_EDITION } from '@/config'
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
import type { InvitationResult } from '@/models/common'
|
||||
import { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip'
|
||||
import { NUM_INFINITE } from '@/app/components/billing/config'
|
||||
import { Plan } from '@/app/components/billing/type'
|
||||
|
||||
@ -3,9 +3,9 @@ import type { FC } from 'react'
|
||||
import * as React from 'react'
|
||||
import { useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Input from '@/app/components/base/input'
|
||||
import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { useMembers } from '@/service/use-common'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ import type { Member } from '@/models/common'
|
||||
import { RiCloseCircleFill, RiErrorWarningFill } from '@remixicon/react'
|
||||
import * as React from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
type Props = {
|
||||
|
||||
@ -4,8 +4,8 @@ import type { Recipient } from '@/app/components/workflow/nodes/human-input/type
|
||||
import type { Member } from '@/models/common'
|
||||
import { useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Input from '@/app/components/base/input'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
const i18nPrefix = 'nodes.humanInput'
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Button from '@/app/components/base/button'
|
||||
import { Triangle } from '@/app/components/base/icons/src/public/education'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useRouter } from '@/next/navigation'
|
||||
import { useLogout } from '@/service/use-common'
|
||||
|
||||
@ -1565,11 +1565,6 @@
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/base/avatar/index.tsx": {
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 2
|
||||
}
|
||||
},
|
||||
"app/components/base/badge.tsx": {
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 2
|
||||
|
||||
@ -171,7 +171,7 @@ export default antfu(
|
||||
},
|
||||
{
|
||||
name: 'dify/base-ui-primitives',
|
||||
files: ['app/components/base/ui/**/*.tsx', 'app/components/base/avatar/**/*.tsx'],
|
||||
files: ['app/components/base/ui/**/*.tsx'],
|
||||
rules: {
|
||||
'react-refresh/only-export-components': 'off',
|
||||
},
|
||||
|
||||
Loading…
Reference in New Issue
Block a user