diff --git a/api/core/app/apps/workflow_logging_callback.py b/api/core/app/apps/workflow_logging_callback.py new file mode 100644 index 0000000000..60683b0f21 --- /dev/null +++ b/api/core/app/apps/workflow_logging_callback.py @@ -0,0 +1,220 @@ +from typing import Optional + +from core.model_runtime.utils.encoders import jsonable_encoder +from core.workflow.callbacks.base_workflow_callback import WorkflowCallback +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + GraphRunFailedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + IterationRunFailedEvent, + IterationRunNextEvent, + IterationRunStartedEvent, + IterationRunSucceededEvent, + NodeRunFailedEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, + ParallelBranchRunFailedEvent, + ParallelBranchRunStartedEvent, + ParallelBranchRunSucceededEvent, +) + +_TEXT_COLOR_MAPPING = { + "blue": "36;1", + "yellow": "33;1", + "pink": "38;5;200", + "green": "32;1", + "red": "31;1", +} + + +class WorkflowLoggingCallback(WorkflowCallback): + def __init__(self) -> None: + self.current_node_id = None + + def on_event(self, event: GraphEngineEvent) -> None: + if isinstance(event, GraphRunStartedEvent): + self.print_text("\n[GraphRunStartedEvent]", color="pink") + elif isinstance(event, GraphRunSucceededEvent): + self.print_text("\n[GraphRunSucceededEvent]", color="green") + elif isinstance(event, GraphRunFailedEvent): + self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red") + elif isinstance(event, NodeRunStartedEvent): + self.on_workflow_node_execute_started(event=event) + elif isinstance(event, NodeRunSucceededEvent): + self.on_workflow_node_execute_succeeded(event=event) + elif isinstance(event, NodeRunFailedEvent): + self.on_workflow_node_execute_failed(event=event) + elif isinstance(event, NodeRunStreamChunkEvent): + self.on_node_text_chunk(event=event) + elif isinstance(event, ParallelBranchRunStartedEvent): + self.on_workflow_parallel_started(event=event) + elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent): + self.on_workflow_parallel_completed(event=event) + elif isinstance(event, IterationRunStartedEvent): + self.on_workflow_iteration_started(event=event) + elif isinstance(event, IterationRunNextEvent): + self.on_workflow_iteration_next(event=event) + elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent): + self.on_workflow_iteration_completed(event=event) + else: + self.print_text(f"\n[{event.__class__.__name__}]", color="blue") + + def on_workflow_node_execute_started(self, event: NodeRunStartedEvent) -> None: + """ + Workflow node execute started + """ + self.print_text("\n[NodeRunStartedEvent]", color="yellow") + self.print_text(f"Node ID: {event.node_id}", color="yellow") + self.print_text(f"Node Title: {event.node_data.title}", color="yellow") + self.print_text(f"Type: {event.node_type.value}", color="yellow") + + def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent) -> None: + """ + Workflow node execute succeeded + """ + route_node_state = event.route_node_state + + self.print_text("\n[NodeRunSucceededEvent]", color="green") + self.print_text(f"Node ID: {event.node_id}", color="green") + self.print_text(f"Node Title: {event.node_data.title}", color="green") + self.print_text(f"Type: {event.node_type.value}", color="green") + + if route_node_state.node_run_result: + node_run_result = route_node_state.node_run_result + self.print_text( + f"Inputs: 
{jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", + color="green", + ) + self.print_text( + f"Process Data: " + f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", + color="green", + ) + self.print_text( + f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", + color="green", + ) + self.print_text( + f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}", + color="green", + ) + + def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent) -> None: + """ + Workflow node execute failed + """ + route_node_state = event.route_node_state + + self.print_text("\n[NodeRunFailedEvent]", color="red") + self.print_text(f"Node ID: {event.node_id}", color="red") + self.print_text(f"Node Title: {event.node_data.title}", color="red") + self.print_text(f"Type: {event.node_type.value}", color="red") + + if route_node_state.node_run_result: + node_run_result = route_node_state.node_run_result + self.print_text(f"Error: {node_run_result.error}", color="red") + self.print_text( + f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", + color="red", + ) + self.print_text( + f"Process Data: " + f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", + color="red", + ) + self.print_text( + f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", + color="red", + ) + + def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None: + """ + Publish text chunk + """ + route_node_state = event.route_node_state + if not self.current_node_id or self.current_node_id != route_node_state.node_id: + self.current_node_id = route_node_state.node_id + self.print_text("\n[NodeRunStreamChunkEvent]") + self.print_text(f"Node ID: {route_node_state.node_id}") + + node_run_result = route_node_state.node_run_result + if node_run_result: + self.print_text( + f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}" + ) + + self.print_text(event.chunk_content, color="pink", end="") + + def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent) -> None: + """ + Publish parallel started + """ + self.print_text("\n[ParallelBranchRunStartedEvent]", color="blue") + self.print_text(f"Parallel ID: {event.parallel_id}", color="blue") + self.print_text(f"Branch ID: {event.parallel_start_node_id}", color="blue") + if event.in_iteration_id: + self.print_text(f"Iteration ID: {event.in_iteration_id}", color="blue") + + def on_workflow_parallel_completed( + self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent + ) -> None: + """ + Publish parallel completed + """ + if isinstance(event, ParallelBranchRunSucceededEvent): + color = "blue" + elif isinstance(event, ParallelBranchRunFailedEvent): + color = "red" + + self.print_text( + "\n[ParallelBranchRunSucceededEvent]" + if isinstance(event, ParallelBranchRunSucceededEvent) + else "\n[ParallelBranchRunFailedEvent]", + color=color, + ) + self.print_text(f"Parallel ID: {event.parallel_id}", color=color) + self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color) + if event.in_iteration_id: + self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color) + + if isinstance(event, ParallelBranchRunFailedEvent): + self.print_text(f"Error: {event.error}", color=color) + + def on_workflow_iteration_started(self, event: IterationRunStartedEvent) 
-> None: + """ + Publish iteration started + """ + self.print_text("\n[IterationRunStartedEvent]", color="blue") + self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue") + + def on_workflow_iteration_next(self, event: IterationRunNextEvent) -> None: + """ + Publish iteration next + """ + self.print_text("\n[IterationRunNextEvent]", color="blue") + self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue") + self.print_text(f"Iteration Index: {event.index}", color="blue") + + def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent) -> None: + """ + Publish iteration completed + """ + self.print_text( + "\n[IterationRunSucceededEvent]" + if isinstance(event, IterationRunSucceededEvent) + else "\n[IterationRunFailedEvent]", + color="blue", + ) + self.print_text(f"Node ID: {event.iteration_id}", color="blue") + + def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None: + """Print text with highlighting and no end characters.""" + text_to_print = self._get_colored_text(text, color) if color else text + print(f"{text_to_print}", end=end) + + def _get_colored_text(self, text: str, color: str) -> str: + """Get colored text.""" + color_str = _TEXT_COLOR_MAPPING[color] + return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" diff --git a/api/core/app/segments/__init__.py b/api/core/app/segments/__init__.py new file mode 100644 index 0000000000..652ef243b4 --- /dev/null +++ b/api/core/app/segments/__init__.py @@ -0,0 +1,49 @@ +from .segment_group import SegmentGroup +from .segments import ( + ArrayAnySegment, + ArraySegment, + FloatSegment, + IntegerSegment, + NoneSegment, + ObjectSegment, + Segment, + StringSegment, +) +from .types import SegmentType +from .variables import ( + ArrayAnyVariable, + ArrayNumberVariable, + ArrayObjectVariable, + ArrayStringVariable, + FloatVariable, + IntegerVariable, + NoneVariable, + ObjectVariable, + SecretVariable, + StringVariable, + Variable, +) + +__all__ = [ + "IntegerVariable", + "FloatVariable", + "ObjectVariable", + "SecretVariable", + "StringVariable", + "ArrayAnyVariable", + "Variable", + "SegmentType", + "SegmentGroup", + "Segment", + "NoneSegment", + "NoneVariable", + "IntegerSegment", + "FloatSegment", + "ObjectSegment", + "ArrayAnySegment", + "StringSegment", + "ArrayStringVariable", + "ArrayNumberVariable", + "ArrayObjectVariable", + "ArraySegment", +] diff --git a/api/core/app/segments/exc.py b/api/core/app/segments/exc.py new file mode 100644 index 0000000000..5cf67c3bac --- /dev/null +++ b/api/core/app/segments/exc.py @@ -0,0 +1,2 @@ +class VariableError(ValueError): + pass diff --git a/api/core/app/segments/factory.py b/api/core/app/segments/factory.py new file mode 100644 index 0000000000..40a69ed4eb --- /dev/null +++ b/api/core/app/segments/factory.py @@ -0,0 +1,76 @@ +from collections.abc import Mapping +from typing import Any + +from configs import dify_config + +from .exc import VariableError +from .segments import ( + ArrayAnySegment, + FloatSegment, + IntegerSegment, + NoneSegment, + ObjectSegment, + Segment, + StringSegment, +) +from .types import SegmentType +from .variables import ( + ArrayNumberVariable, + ArrayObjectVariable, + ArrayStringVariable, + FloatVariable, + IntegerVariable, + ObjectVariable, + SecretVariable, + StringVariable, + Variable, +) + + +def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: + if (value_type := mapping.get("value_type")) is None: + raise VariableError("missing value type") 
+ if not mapping.get("name"): + raise VariableError("missing name") + if (value := mapping.get("value")) is None: + raise VariableError("missing value") + match value_type: + case SegmentType.STRING: + result = StringVariable.model_validate(mapping) + case SegmentType.SECRET: + result = SecretVariable.model_validate(mapping) + case SegmentType.NUMBER if isinstance(value, int): + result = IntegerVariable.model_validate(mapping) + case SegmentType.NUMBER if isinstance(value, float): + result = FloatVariable.model_validate(mapping) + case SegmentType.NUMBER if not isinstance(value, float | int): + raise VariableError(f"invalid number value {value}") + case SegmentType.OBJECT if isinstance(value, dict): + result = ObjectVariable.model_validate(mapping) + case SegmentType.ARRAY_STRING if isinstance(value, list): + result = ArrayStringVariable.model_validate(mapping) + case SegmentType.ARRAY_NUMBER if isinstance(value, list): + result = ArrayNumberVariable.model_validate(mapping) + case SegmentType.ARRAY_OBJECT if isinstance(value, list): + result = ArrayObjectVariable.model_validate(mapping) + case _: + raise VariableError(f"not supported value type {value_type}") + if result.size > dify_config.MAX_VARIABLE_SIZE: + raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}") + return result + + +def build_segment(value: Any, /) -> Segment: + if value is None: + return NoneSegment() + if isinstance(value, str): + return StringSegment(value=value) + if isinstance(value, int): + return IntegerSegment(value=value) + if isinstance(value, float): + return FloatSegment(value=value) + if isinstance(value, dict): + return ObjectSegment(value=value) + if isinstance(value, list): + return ArrayAnySegment(value=value) + raise ValueError(f"not supported value {value}") diff --git a/api/core/app/segments/parser.py b/api/core/app/segments/parser.py new file mode 100644 index 0000000000..3c4d7046f4 --- /dev/null +++ b/api/core/app/segments/parser.py @@ -0,0 +1,18 @@ +import re + +from core.workflow.entities.variable_pool import VariablePool + +from . import SegmentGroup, factory + +VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") + + +def convert_template(*, template: str, variable_pool: VariablePool): + parts = re.split(VARIABLE_PATTERN, template) + segments = [] + for part in filter(lambda x: x, parts): + if "." 
in part and (value := variable_pool.get(part.split("."))): + segments.append(value) + else: + segments.append(factory.build_segment(part)) + return SegmentGroup(value=segments) diff --git a/api/core/app/segments/segment_group.py b/api/core/app/segments/segment_group.py new file mode 100644 index 0000000000..b363255b2c --- /dev/null +++ b/api/core/app/segments/segment_group.py @@ -0,0 +1,22 @@ +from .segments import Segment +from .types import SegmentType + + +class SegmentGroup(Segment): + value_type: SegmentType = SegmentType.GROUP + value: list[Segment] + + @property + def text(self): + return "".join([segment.text for segment in self.value]) + + @property + def log(self): + return "".join([segment.log for segment in self.value]) + + @property + def markdown(self): + return "".join([segment.markdown for segment in self.value]) + + def to_object(self): + return [segment.to_object() for segment in self.value] diff --git a/api/core/app/segments/segments.py b/api/core/app/segments/segments.py new file mode 100644 index 0000000000..b26b3c8291 --- /dev/null +++ b/api/core/app/segments/segments.py @@ -0,0 +1,126 @@ +import json +import sys +from collections.abc import Mapping, Sequence +from typing import Any + +from pydantic import BaseModel, ConfigDict, field_validator + +from .types import SegmentType + + +class Segment(BaseModel): + model_config = ConfigDict(frozen=True) + + value_type: SegmentType + value: Any + + @field_validator("value_type") + @classmethod + def validate_value_type(cls, value): + """ + This validator checks if the provided value is equal to the default value of the 'value_type' field. + If the value is different, a ValueError is raised. + """ + if value != cls.model_fields["value_type"].default: + raise ValueError("Cannot modify 'value_type'") + return value + + @property + def text(self) -> str: + return str(self.value) + + @property + def log(self) -> str: + return str(self.value) + + @property + def markdown(self) -> str: + return str(self.value) + + @property + def size(self) -> int: + return sys.getsizeof(self.value) + + def to_object(self) -> Any: + return self.value + + +class NoneSegment(Segment): + value_type: SegmentType = SegmentType.NONE + value: None = None + + @property + def text(self) -> str: + return "null" + + @property + def log(self) -> str: + return "null" + + @property + def markdown(self) -> str: + return "null" + + +class StringSegment(Segment): + value_type: SegmentType = SegmentType.STRING + value: str + + +class FloatSegment(Segment): + value_type: SegmentType = SegmentType.NUMBER + value: float + + +class IntegerSegment(Segment): + value_type: SegmentType = SegmentType.NUMBER + value: int + + +class ObjectSegment(Segment): + value_type: SegmentType = SegmentType.OBJECT + value: Mapping[str, Any] + + @property + def text(self) -> str: + return json.dumps(self.model_dump()["value"], ensure_ascii=False) + + @property + def log(self) -> str: + return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2) + + @property + def markdown(self) -> str: + return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2) + + +class ArraySegment(Segment): + @property + def markdown(self) -> str: + items = [] + for item in self.value: + if hasattr(item, "to_markdown"): + items.append(item.to_markdown()) + else: + items.append(str(item)) + return "\n".join(items) + + +class ArrayAnySegment(ArraySegment): + value_type: SegmentType = SegmentType.ARRAY_ANY + value: Sequence[Any] + + +class ArrayStringSegment(ArraySegment): + value_type: 
SegmentType = SegmentType.ARRAY_STRING + value: Sequence[str] + + +class ArrayNumberSegment(ArraySegment): + value_type: SegmentType = SegmentType.ARRAY_NUMBER + value: Sequence[float | int] + + +class ArrayObjectSegment(ArraySegment): + value_type: SegmentType = SegmentType.ARRAY_OBJECT + value: Sequence[Mapping[str, Any]] diff --git a/api/core/app/segments/types.py b/api/core/app/segments/types.py new file mode 100644 index 0000000000..9cf0856df5 --- /dev/null +++ b/api/core/app/segments/types.py @@ -0,0 +1,15 @@ +from enum import Enum + + +class SegmentType(str, Enum): + NONE = "none" + NUMBER = "number" + STRING = "string" + SECRET = "secret" + ARRAY_ANY = "array[any]" + ARRAY_STRING = "array[string]" + ARRAY_NUMBER = "array[number]" + ARRAY_OBJECT = "array[object]" + OBJECT = "object" + + GROUP = "group" diff --git a/api/core/app/segments/variables.py b/api/core/app/segments/variables.py new file mode 100644 index 0000000000..f0e403ab8d --- /dev/null +++ b/api/core/app/segments/variables.py @@ -0,0 +1,75 @@ +from pydantic import Field + +from core.helper import encrypter + +from .segments import ( + ArrayAnySegment, + ArrayNumberSegment, + ArrayObjectSegment, + ArrayStringSegment, + FloatSegment, + IntegerSegment, + NoneSegment, + ObjectSegment, + Segment, + StringSegment, +) +from .types import SegmentType + + +class Variable(Segment): + """ + A variable is a segment that has a name. + """ + + id: str = Field( + default="", + description="Unique identity for variable. It's only used by environment variables now.", + ) + name: str + description: str = Field(default="", description="Description of the variable.") + + +class StringVariable(StringSegment, Variable): + pass + + +class FloatVariable(FloatSegment, Variable): + pass + + +class IntegerVariable(IntegerSegment, Variable): + pass + + +class ObjectVariable(ObjectSegment, Variable): + pass + + +class ArrayAnyVariable(ArrayAnySegment, Variable): + pass + + +class ArrayStringVariable(ArrayStringSegment, Variable): + pass + + +class ArrayNumberVariable(ArrayNumberSegment, Variable): + pass + + +class ArrayObjectVariable(ArrayObjectSegment, Variable): + pass + + +class SecretVariable(StringVariable): + value_type: SegmentType = SegmentType.SECRET + + @property + def log(self) -> str: + return encrypter.obfuscated_token(self.value) + + +class NoneVariable(NoneSegment, Variable): + value_type: SegmentType = SegmentType.NONE + value: None = None diff --git a/api/core/entities/message_entities.py b/api/core/entities/message_entities.py new file mode 100644 index 0000000000..10bc9f6ed7 --- /dev/null +++ b/api/core/entities/message_entities.py @@ -0,0 +1,29 @@ +import enum +from typing import Any + +from pydantic import BaseModel + + +class PromptMessageFileType(enum.Enum): + IMAGE = "image" + + @staticmethod + def value_of(value): + for member in PromptMessageFileType: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class PromptMessageFile(BaseModel): + type: PromptMessageFileType + data: Any = None + + +class ImagePromptMessageFile(PromptMessageFile): + class DETAIL(enum.Enum): + LOW = "low" + HIGH = "high" + + type: PromptMessageFileType = PromptMessageFileType.IMAGE + detail: DETAIL = DETAIL.LOW diff --git a/api/core/file/file_obj.py b/api/core/file/file_obj.py new file mode 100644 index 0000000000..5c4e694025 --- /dev/null +++ b/api/core/file/file_obj.py @@ -0,0 +1,145 @@ +import enum +from typing import Any, Optional + +from pydantic import BaseModel + +from 
core.file.tool_file_parser import ToolFileParser +from core.file.upload_file_parser import UploadFileParser +from core.model_runtime.entities.message_entities import ImagePromptMessageContent +from extensions.ext_database import db + + +class FileExtraConfig(BaseModel): + """ + File Upload Entity. + """ + + image_config: Optional[dict[str, Any]] = None + + +class FileType(enum.Enum): + IMAGE = "image" + + @staticmethod + def value_of(value): + for member in FileType: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class FileTransferMethod(enum.Enum): + REMOTE_URL = "remote_url" + LOCAL_FILE = "local_file" + TOOL_FILE = "tool_file" + + @staticmethod + def value_of(value): + for member in FileTransferMethod: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class FileBelongsTo(enum.Enum): + USER = "user" + ASSISTANT = "assistant" + + @staticmethod + def value_of(value): + for member in FileBelongsTo: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class FileVar(BaseModel): + id: Optional[str] = None # message file id + tenant_id: str + type: FileType + transfer_method: FileTransferMethod + url: Optional[str] = None # remote url + related_id: Optional[str] = None + extra_config: Optional[FileExtraConfig] = None + filename: Optional[str] = None + extension: Optional[str] = None + mime_type: Optional[str] = None + + def to_dict(self) -> dict: + return { + "__variant": self.__class__.__name__, + "tenant_id": self.tenant_id, + "type": self.type.value, + "transfer_method": self.transfer_method.value, + "url": self.preview_url, + "remote_url": self.url, + "related_id": self.related_id, + "filename": self.filename, + "extension": self.extension, + "mime_type": self.mime_type, + } + + def to_markdown(self) -> str: + """ + Convert file to markdown + :return: + """ + preview_url = self.preview_url + if self.type == FileType.IMAGE: + text = f'![{self.filename or ""}]({preview_url})' + else: + text = f"[{self.filename or preview_url}]({preview_url})" + + return text + + @property + def data(self) -> Optional[str]: + """ + Get image data, file signed url or base64 data + depending on config MULTIMODAL_SEND_IMAGE_FORMAT + :return: + """ + return self._get_data() + + @property + def preview_url(self) -> Optional[str]: + """ + Get signed preview url + :return: + """ + return self._get_data(force_url=True) + + @property + def prompt_message_content(self) -> ImagePromptMessageContent: + if self.type == FileType.IMAGE: + image_config = self.extra_config.image_config + + return ImagePromptMessageContent( + data=self.data, + detail=ImagePromptMessageContent.DETAIL.HIGH + if image_config.get("detail") == "high" + else ImagePromptMessageContent.DETAIL.LOW, + ) + + def _get_data(self, force_url: bool = False) -> Optional[str]: + from models.model import UploadFile + + if self.type == FileType.IMAGE: + if self.transfer_method == FileTransferMethod.REMOTE_URL: + return self.url + elif self.transfer_method == FileTransferMethod.LOCAL_FILE: + upload_file = ( + db.session.query(UploadFile) + .filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id) + .first() + ) + + return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url) + elif self.transfer_method == FileTransferMethod.TOOL_FILE: + extension = self.extension + # add sign url + return 
ToolFileParser.get_tool_file_manager().sign_file( + tool_file_id=self.related_id, extension=extension + ) + + return None diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py new file mode 100644 index 0000000000..641686bd7c --- /dev/null +++ b/api/core/file/message_file_parser.py @@ -0,0 +1,243 @@ +import re +from collections.abc import Mapping, Sequence +from typing import Any, Union +from urllib.parse import parse_qs, urlparse + +import requests + +from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar +from extensions.ext_database import db +from models.account import Account +from models.model import EndUser, MessageFile, UploadFile +from services.file_service import IMAGE_EXTENSIONS + + +class MessageFileParser: + def __init__(self, tenant_id: str, app_id: str) -> None: + self.tenant_id = tenant_id + self.app_id = app_id + + def validate_and_transform_files_arg( + self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser] + ) -> list[FileVar]: + """ + validate and transform files arg + + :param files: + :param file_extra_config: + :param user: + :return: + """ + for file in files: + if not isinstance(file, dict): + raise ValueError("Invalid file format, must be dict") + if not file.get("type"): + raise ValueError("Missing file type") + FileType.value_of(file.get("type")) + if not file.get("transfer_method"): + raise ValueError("Missing file transfer method") + FileTransferMethod.value_of(file.get("transfer_method")) + if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value: + if not file.get("url"): + raise ValueError("Missing file url") + if not file.get("url").startswith("http"): + raise ValueError("Invalid file url") + if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"): + raise ValueError("Missing file upload_file_id") + if file.get("transfer_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"): + raise ValueError("Missing file tool_file_id") + + # transform files to file objs + type_file_objs = self._to_file_objs(files, file_extra_config) + + # validate files + new_files = [] + for file_type, file_objs in type_file_objs.items(): + if file_type == FileType.IMAGE: + # parse and validate files + image_config = file_extra_config.image_config + + # check if image file feature is enabled + if not image_config: + continue + + # Validate number of files + if len(files) > image_config["number_limits"]: + raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}") + + for file_obj in file_objs: + # Validate transfer method + if file_obj.transfer_method.value not in image_config["transfer_methods"]: + raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}") + + # Validate file type + if file_obj.type != FileType.IMAGE: + raise ValueError(f"Invalid file type: {file_obj.type}") + + if file_obj.transfer_method == FileTransferMethod.REMOTE_URL: + # check remote url valid and is image + result, error = self._check_image_remote_url(file_obj.url) + if result is False: + raise ValueError(error) + elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE: + # get upload file from upload_file_id + upload_file = ( + db.session.query(UploadFile) + .filter( + UploadFile.id == file_obj.related_id, + UploadFile.tenant_id == self.tenant_id, + UploadFile.created_by == user.id, + UploadFile.created_by_role ==
("account" if isinstance(user, Account) else "end_user"), + UploadFile.extension.in_(IMAGE_EXTENSIONS), + ) + .first() + ) + + # check upload file is belong to tenant and user + if not upload_file: + raise ValueError("Invalid upload file") + + new_files.append(file_obj) + + # return all file objs + return new_files + + def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig): + """ + transform message files + + :param files: + :param file_extra_config: + :return: + """ + # transform files to file objs + type_file_objs = self._to_file_objs(files, file_extra_config) + + # return all file objs + return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs] + + def _to_file_objs( + self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig + ) -> dict[FileType, list[FileVar]]: + """ + transform files to file objs + + :param files: + :param file_extra_config: + :return: + """ + type_file_objs: dict[FileType, list[FileVar]] = { + # Currently only support image + FileType.IMAGE: [] + } + + if not files: + return type_file_objs + + # group by file type and convert file args or message files to FileObj + for file in files: + if isinstance(file, MessageFile): + if file.belongs_to == FileBelongsTo.ASSISTANT.value: + continue + + file_obj = self._to_file_obj(file, file_extra_config) + if file_obj.type not in type_file_objs: + continue + + type_file_objs[file_obj.type].append(file_obj) + + return type_file_objs + + def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig): + """ + transform file to file obj + + :param file: + :return: + """ + if isinstance(file, dict): + transfer_method = FileTransferMethod.value_of(file.get("transfer_method")) + if transfer_method != FileTransferMethod.TOOL_FILE: + return FileVar( + tenant_id=self.tenant_id, + type=FileType.value_of(file.get("type")), + transfer_method=transfer_method, + url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None, + related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None, + extra_config=file_extra_config, + ) + return FileVar( + tenant_id=self.tenant_id, + type=FileType.value_of(file.get("type")), + transfer_method=transfer_method, + url=None, + related_id=file.get("tool_file_id"), + extra_config=file_extra_config, + ) + else: + return FileVar( + id=file.id, + tenant_id=self.tenant_id, + type=FileType.value_of(file.type), + transfer_method=FileTransferMethod.value_of(file.transfer_method), + url=file.url, + related_id=file.upload_file_id or None, + extra_config=file_extra_config, + ) + + def _check_image_remote_url(self, url): + try: + headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)" + " Chrome/91.0.4472.124 Safari/537.36" + } + + def is_s3_presigned_url(url): + try: + parsed_url = urlparse(url) + if "amazonaws.com" not in parsed_url.netloc: + return False + query_params = parse_qs(parsed_url.query) + + def check_presign_v2(query_params): + required_params = ["Signature", "Expires"] + for param in required_params: + if param not in query_params: + return False + if not query_params["Expires"][0].isdigit(): + return False + signature = query_params["Signature"][0] + if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature): + return False + + return True + + def check_presign_v4(query_params): + required_params = ["X-Amz-Signature", "X-Amz-Expires"] + for param in required_params: + if param not in 
query_params: + return False + if not query_params["X-Amz-Expires"][0].isdigit(): + return False + signature = query_params["X-Amz-Signature"][0] + if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature): + return False + + return True + + return check_presign_v4(query_params) or check_presign_v2(query_params) + except Exception: + return False + + if is_s3_presigned_url(url): + response = requests.get(url, headers=headers, allow_redirects=True) + if response.status_code in {200, 304}: + return True, "" + + response = requests.head(url, headers=headers, allow_redirects=True) + if response.status_code in {200, 304}: + return True, "" + else: + return False, "URL does not exist." + except requests.RequestException as e: + return False, f"Error checking URL: {e}" diff --git a/api/core/file/upload_file_parser.py b/api/core/file/upload_file_parser.py new file mode 100644 index 0000000000..a8c1fd4d02 --- /dev/null +++ b/api/core/file/upload_file_parser.py @@ -0,0 +1,79 @@ +import base64 +import hashlib +import hmac +import logging +import os +import time +from typing import Optional + +from configs import dify_config +from extensions.ext_storage import storage + +IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] +IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) + + +class UploadFileParser: + @classmethod + def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]: + if not upload_file: + return None + + if upload_file.extension not in IMAGE_EXTENSIONS: + return None + + if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url: + return cls.get_signed_temp_image_url(upload_file.id) + else: + # get image file base64 + try: + data = storage.load(upload_file.key) + except FileNotFoundError: + logging.error(f"File not found: {upload_file.key}") + return None + + encoded_string = base64.b64encode(data).decode("utf-8") + return f"data:{upload_file.mime_type};base64,{encoded_string}" + + @classmethod + def get_signed_temp_image_url(cls, upload_file_id) -> str: + """ + get signed url from upload file + + :param upload_file: UploadFile object + :return: + """ + base_url = dify_config.FILES_URL + image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview" + + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + + @classmethod + def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: + """ + verify signature + + :param upload_file_id: file id + :param timestamp: timestamp + :param nonce: nonce + :param sign: signature + :return: + """ + data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() + recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() + + # verify signature + if sign != recalculated_encoded_sign: + return False + + current_time = int(time.time()) + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT diff --git a/api/core/tools/provider/builtin/feishu_base/_assets/icon.svg 
b/api/core/tools/provider/builtin/feishu_base/_assets/icon.svg new file mode 100644 index 0000000000..2663a0f59e --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/_assets/icon.svg @@ -0,0 +1,47 @@ + + + + diff --git a/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.py b/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.py new file mode 100644 index 0000000000..4a605fbffe --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.py @@ -0,0 +1,56 @@ +import json +from typing import Any, Union + +import httpx + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class AddBaseRecordTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records" + + access_token = tool_parameters.get("Authorization", "") + if not access_token: + return self.create_text_message("Invalid parameter access_token") + + app_token = tool_parameters.get("app_token", "") + if not app_token: + return self.create_text_message("Invalid parameter app_token") + + table_id = tool_parameters.get("table_id", "") + if not table_id: + return self.create_text_message("Invalid parameter table_id") + + fields = tool_parameters.get("fields", "") + if not fields: + return self.create_text_message("Invalid parameter fields") + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + } + + params = {} + payload = {"fields": json.loads(fields)} + + try: + res = httpx.post( + url.format(app_token=app_token, table_id=table_id), + headers=headers, + params=params, + json=payload, + timeout=30, + ) + res_json = res.json() + if res.is_success: + return self.create_text_message(text=json.dumps(res_json)) + else: + return self.create_text_message( + f"Failed to add base record, status code: {res.status_code}, response: {res.text}" + ) + except Exception as e: + return self.create_text_message("Failed to add base record. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.yaml b/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.yaml new file mode 100644 index 0000000000..3ce0154efd --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.yaml @@ -0,0 +1,66 @@ +identity: + name: add_base_record + author: Doug Lea + label: + en_US: Add Base Record + zh_Hans: 在多维表格数据表中新增一条记录 +description: + human: + en_US: Add Base Record + zh_Hans: | + 在多维表格数据表中新增一条记录,详细请参考:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/app-table-record/create + llm: Add a new record in the multidimensional table data table. 
+parameters: + - name: Authorization + type: string + required: true + label: + en_US: token + zh_Hans: 凭证 + human_description: + en_US: API access token parameter, tenant_access_token or user_access_token + zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token + llm_description: API access token parameter, tenant_access_token or user_access_token + form: llm + + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: 多维表格 + human_description: + en_US: bitable app token + zh_Hans: 多维表格的唯一标识符 app_token + llm_description: bitable app token + form: llm + + - name: table_id + type: string + required: true + label: + en_US: table_id + zh_Hans: 多维表格的数据表 + human_description: + en_US: bitable table id + zh_Hans: 多维表格数据表的唯一标识符 table_id + llm_description: bitable table id + form: llm + + - name: fields + type: string + required: true + label: + en_US: fields + zh_Hans: 数据表的列字段内容 + human_description: + en_US: The fields of the Base data table are the columns of the data table. + zh_Hans: | + 要增加一行多维表格记录,字段结构拼接如下:{"多行文本":"多行文本内容","单选":"选项1","多选":["选项1","选项2"],"复选框":true,"人员":[{"id":"ou_2910013f1e6456f16a0ce75ede950a0a"}],"群组":[{"id":"oc_cd07f55f14d6f4a4f1b51504e7e97f48"}],"电话号码":"13026162666"} + 当前接口支持的字段类型为:多行文本、单选、条码、多选、日期、人员、附件、复选框、超链接、数字、单向关联、双向关联、电话号码、地理位置。 + 不同类型字段的数据结构请参考数据结构概述:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure + llm_description: | + 要增加一行多维表格记录,字段结构拼接如下:{"多行文本":"多行文本内容","单选":"选项1","多选":["选项1","选项2"],"复选框":true,"人员":[{"id":"ou_2910013f1e6456f16a0ce75ede950a0a"}],"群组":[{"id":"oc_cd07f55f14d6f4a4f1b51504e7e97f48"}],"电话号码":"13026162666"} + 当前接口支持的字段类型为:多行文本、单选、条码、多选、日期、人员、附件、复选框、超链接、数字、单向关联、双向关联、电话号码、地理位置。 + 不同类型字段的数据结构请参考数据结构概述:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure + form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py b/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py new file mode 100644 index 0000000000..b05d700113 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py @@ -0,0 +1,48 @@ +import json +from typing import Any, Union + +import httpx + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class CreateBaseTableTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables" + + access_token = tool_parameters.get("Authorization", "") + if not access_token: + return self.create_text_message("Invalid parameter access_token") + + app_token = tool_parameters.get("app_token", "") + if not app_token: + return self.create_text_message("Invalid parameter app_token") + + name = tool_parameters.get("name", "") + + fields = tool_parameters.get("fields", "") + if not fields: + return self.create_text_message("Invalid parameter fields") + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + } + + params = {} + payload = {"table": {"name": name, "fields": json.loads(fields)}} + + try: + res = httpx.post(url.format(app_token=app_token), headers=headers, params=params, json=payload, timeout=30) + res_json = res.json() + if res.is_success: + return self.create_text_message(text=json.dumps(res_json)) + else: + return self.create_text_message( + f"Failed to create base table, status code: 
{res.status_code}, response: {res.text}" + ) + except Exception as e: + return self.create_text_message("Failed to create base table. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.yaml b/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.yaml new file mode 100644 index 0000000000..48c46bec14 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.yaml @@ -0,0 +1,106 @@ +identity: + name: create_base_table + author: Doug Lea + label: + en_US: Create Base Table + zh_Hans: 多维表格新增一个数据表 +description: + human: + en_US: Create base table + zh_Hans: | + 多维表格新增一个数据表,详细请参考:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/app-table/create + llm: A tool for add a new data table to the multidimensional table. +parameters: + - name: Authorization + type: string + required: true + label: + en_US: token + zh_Hans: 凭证 + human_description: + en_US: API access token parameter, tenant_access_token or user_access_token + zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token + llm_description: API access token parameter, tenant_access_token or user_access_token + form: llm + + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: 多维表格 + human_description: + en_US: bitable app token + zh_Hans: 多维表格的唯一标识符 app_token + llm_description: bitable app token + form: llm + + - name: name + type: string + required: false + label: + en_US: name + zh_Hans: name + human_description: + en_US: Multidimensional table data table name + zh_Hans: 多维表格数据表名称 + llm_description: Multidimensional table data table name + form: llm + + - name: fields + type: string + required: true + label: + en_US: fields + zh_Hans: fields + human_description: + en_US: Initial fields of the data table + zh_Hans: | + 数据表的初始字段,格式为:[{"field_name":"多行文本","type":1},{"field_name":"数字","type":2},{"field_name":"单选","type":3},{"field_name":"多选","type":4},{"field_name":"日期","type":5}]。 + field_name:字段名; + type: 字段类型;可选值有 + 1:多行文本 + 2:数字 + 3:单选 + 4:多选 + 5:日期 + 7:复选框 + 11:人员 + 13:电话号码 + 15:超链接 + 17:附件 + 18:单向关联 + 20:公式 + 21:双向关联 + 22:地理位置 + 23:群组 + 1001:创建时间 + 1002:最后更新时间 + 1003:创建人 + 1004:修改人 + 1005:自动编号 + llm_description: | + 数据表的初始字段,格式为:[{"field_name":"多行文本","type":1},{"field_name":"数字","type":2},{"field_name":"单选","type":3},{"field_name":"多选","type":4},{"field_name":"日期","type":5}]。 + field_name:字段名; + type: 字段类型;可选值有 + 1:多行文本 + 2:数字 + 3:单选 + 4:多选 + 5:日期 + 7:复选框 + 11:人员 + 13:电话号码 + 15:超链接 + 17:附件 + 18:单向关联 + 20:公式 + 21:双向关联 + 22:地理位置 + 23:群组 + 1001:创建时间 + 1002:最后更新时间 + 1003:创建人 + 1004:修改人 + 1005:自动编号 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py new file mode 100644 index 0000000000..862eb2171b --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py @@ -0,0 +1,56 @@ +import json +from typing import Any, Union + +import httpx + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class DeleteBaseRecordsTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/batch_delete" + + access_token = tool_parameters.get("Authorization", "") + if not access_token: + return self.create_text_message("Invalid 
parameter access_token") + + app_token = tool_parameters.get("app_token", "") + if not app_token: + return self.create_text_message("Invalid parameter app_token") + + table_id = tool_parameters.get("table_id", "") + if not table_id: + return self.create_text_message("Invalid parameter table_id") + + record_ids = tool_parameters.get("record_ids", "") + if not record_ids: + return self.create_text_message("Invalid parameter record_ids") + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + } + + params = {} + payload = {"records": json.loads(record_ids)} + + try: + res = httpx.post( + url.format(app_token=app_token, table_id=table_id), + headers=headers, + params=params, + json=payload, + timeout=30, + ) + res_json = res.json() + if res.is_success: + return self.create_text_message(text=json.dumps(res_json)) + else: + return self.create_text_message( + f"Failed to delete base records, status code: {res.status_code}, response: {res.text}" + ) + except Exception as e: + return self.create_text_message("Failed to delete base records. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.yaml b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.yaml new file mode 100644 index 0000000000..595b287029 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.yaml @@ -0,0 +1,60 @@ +identity: + name: delete_base_records + author: Doug Lea + label: + en_US: Delete Base Records + zh_Hans: 在多维表格数据表中删除多条记录 +description: + human: + en_US: Delete base records + zh_Hans: | + 该接口用于删除多维表格数据表中的多条记录,单次调用中最多删除 500 条记录。 + llm: A tool for delete multiple records in a multidimensional table data table, up to 500 records can be deleted in a single call. 
+parameters: + - name: Authorization + type: string + required: true + label: + en_US: token + zh_Hans: 凭证 + human_description: + en_US: API access token parameter, tenant_access_token or user_access_token + zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token + llm_description: API access token parameter, tenant_access_token or user_access_token + form: llm + + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: 多维表格 + human_description: + en_US: bitable app token + zh_Hans: 多维表格的唯一标识符 app_token + llm_description: bitable app token + form: llm + + - name: table_id + type: string + required: true + label: + en_US: table_id + zh_Hans: 多维表格的数据表 + human_description: + en_US: bitable table id + zh_Hans: 多维表格数据表的唯一标识符 table_id + llm_description: bitable table id + form: llm + + - name: record_ids + type: string + required: true + label: + en_US: record_ids + zh_Hans: record_ids + human_description: + en_US: A list of multiple record IDs to be deleted, for example ["recwNXzPQv","recpCsf4ME"] + zh_Hans: 待删除的多条记录id列表,示例为 ["recwNXzPQv","recpCsf4ME"] + llm_description: A list of multiple record IDs to be deleted, for example ["recwNXzPQv","recpCsf4ME"] + form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py new file mode 100644 index 0000000000..f512186303 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py @@ -0,0 +1,46 @@ +import json +from typing import Any, Union + +import httpx + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class DeleteBaseTablesTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/batch_delete" + + access_token = tool_parameters.get("Authorization", "") + if not access_token: + return self.create_text_message("Invalid parameter access_token") + + app_token = tool_parameters.get("app_token", "") + if not app_token: + return self.create_text_message("Invalid parameter app_token") + + table_ids = tool_parameters.get("table_ids", "") + if not table_ids: + return self.create_text_message("Invalid parameter table_ids") + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + } + + params = {} + payload = {"table_ids": json.loads(table_ids)} + + try: + res = httpx.post(url.format(app_token=app_token), headers=headers, params=params, json=payload, timeout=30) + res_json = res.json() + if res.is_success: + return self.create_text_message(text=json.dumps(res_json)) + else: + return self.create_text_message( + f"Failed to delete base tables, status code: {res.status_code}, response: {res.text}" + ) + except Exception as e: + return self.create_text_message("Failed to delete base tables. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.yaml b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.yaml new file mode 100644 index 0000000000..5d72814363 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.yaml @@ -0,0 +1,48 @@ +identity: + name: delete_base_tables + author: Doug Lea + label: + en_US: Delete Base Tables + zh_Hans: 删除多维表格中的数据表 +description: + human: + en_US: Delete base tables + zh_Hans: | + 删除多维表格中的数据表 + llm: A tool for deleting a data table in a multidimensional table +parameters: + - name: Authorization + type: string + required: true + label: + en_US: token + zh_Hans: 凭证 + human_description: + en_US: API access token parameter, tenant_access_token or user_access_token + zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token + llm_description: API access token parameter, tenant_access_token or user_access_token + form: llm + + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: 多维表格 + human_description: + en_US: bitable app token + zh_Hans: 多维表格的唯一标识符 app_token + llm_description: bitable app token + form: llm + + - name: table_ids + type: string + required: true + label: + en_US: table_ids + zh_Hans: table_ids + human_description: + en_US: The ID list of the data tables to be deleted. Currently, a maximum of 50 data tables can be deleted at a time. The example is ["tbl1TkhyTWDkSoZ3","tblsRc9GRRXKqhvW"] + zh_Hans: 待删除数据表的id列表,当前一次操作最多支持50个数据表,示例为 ["tbl1TkhyTWDkSoZ3","tblsRc9GRRXKqhvW"] + llm_description: The ID list of the data tables to be deleted. Currently, a maximum of 50 data tables can be deleted at a time. The example is ["tbl1TkhyTWDkSoZ3","tblsRc9GRRXKqhvW"] + form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py b/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py new file mode 100644 index 0000000000..2ea61d0068 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py @@ -0,0 +1,48 @@ +import json +from typing import Any, Union + +import httpx + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class GetTenantAccessTokenTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal" + + app_id = tool_parameters.get("app_id", "") + if not app_id: + return self.create_text_message("Invalid parameter app_id") + + app_secret = tool_parameters.get("app_secret", "") + if not app_secret: + return self.create_text_message("Invalid parameter app_secret") + + headers = { + "Content-Type": "application/json", + } + params = {} + payload = {"app_id": app_id, "app_secret": app_secret} + + """ + { + "code": 0, + "msg": "ok", + "tenant_access_token": "t-caecc734c2e3328a62489fe0648c4b98779515d3", + "expire": 7200 + } + """ + try: + res = httpx.post(url, headers=headers, params=params, json=payload, timeout=30) + res_json = res.json() + if res.is_success: + return self.create_text_message(text=json.dumps(res_json)) + else: + return self.create_text_message( + f"Failed to get tenant access token, status code: {res.status_code}, response: {res.text}" + ) + except Exception as e: + return self.create_text_message("Failed to get tenant access token. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.yaml b/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.yaml new file mode 100644 index 0000000000..88acc27e06 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.yaml @@ -0,0 +1,39 @@ +identity: + name: get_tenant_access_token + author: Doug Lea + label: + en_US: Get Tenant Access Token + zh_Hans: 获取飞书自建应用的 tenant_access_token +description: + human: + en_US: Get tenant access token + zh_Hans: | + 获取飞书自建应用的 tenant_access_token,响应体示例: + {"code":0,"msg":"ok","tenant_access_token":"t-caecc734c2e3328a62489fe0648c4b98779515d3","expire":7200} + tenant_access_token: 租户访问凭证; + expire: tenant_access_token 的过期时间,单位为秒; + llm: A tool for obtaining a tenant access token. The input parameters must include app_id and app_secret. +parameters: + - name: app_id + type: string + required: true + label: + en_US: app_id + zh_Hans: 应用唯一标识 + human_description: + en_US: app_id is the unique identifier of the Lark Open Platform application + zh_Hans: app_id 是飞书开放平台应用的唯一标识 + llm_description: app_id is the unique identifier of the Lark Open Platform application + form: llm + + - name: app_secret + type: secret-input + required: true + label: + en_US: app_secret + zh_Hans: 应用秘钥 + human_description: + en_US: app_secret is the secret key of the application + zh_Hans: app_secret 是应用的秘钥 + llm_description: app_secret is the secret key of the application + form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py b/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py new file mode 100644 index 0000000000..e579d02f69 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py @@ -0,0 +1,65 @@ +import json +from typing import Any, Union + +import httpx + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class ListBaseRecordsTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/search" + + access_token = tool_parameters.get("Authorization", "") + if not access_token: + return self.create_text_message("Invalid parameter access_token") + + app_token = tool_parameters.get("app_token", "") + if not app_token: + return self.create_text_message("Invalid parameter app_token") + + table_id = tool_parameters.get("table_id", "") + if not table_id: + return self.create_text_message("Invalid parameter table_id") + + page_token = tool_parameters.get("page_token", "") + page_size = tool_parameters.get("page_size", "") + sort_condition = tool_parameters.get("sort_condition", "") + filter_condition = tool_parameters.get("filter_condition", "") + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + } + + params = { + "page_token": page_token, + "page_size": page_size, + } + + payload = {"automatic_fields": True} + if sort_condition: + payload["sort"] = json.loads(sort_condition) + if filter_condition: + payload["filter"] = json.loads(filter_condition) + + try: + res = httpx.post( + url.format(app_token=app_token, table_id=table_id), + headers=headers, + params=params, + json=payload, + timeout=30, + ) + res_json = res.json() + if res.is_success: + return 
self.create_text_message(text=json.dumps(res_json)) + else: + return self.create_text_message( + f"Failed to list base records, status code: {res.status_code}, response: {res.text}" + ) + except Exception as e: + return self.create_text_message("Failed to list base records. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.yaml b/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.yaml new file mode 100644 index 0000000000..8647c880a6 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.yaml @@ -0,0 +1,108 @@ +identity: + name: list_base_records + author: Doug Lea + label: + en_US: List Base Records + zh_Hans: 查询多维表格数据表中的现有记录 +description: + human: + en_US: List base records + zh_Hans: | + 查询多维表格数据表中的现有记录,单次最多查询 500 行记录,支持分页获取。 + llm: Query existing records in a multidimensional table data table. A maximum of 500 rows of records can be queried at a time, and paging retrieval is supported. +parameters: + - name: Authorization + type: string + required: true + label: + en_US: token + zh_Hans: 凭证 + human_description: + en_US: API access token parameter, tenant_access_token or user_access_token + zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token + llm_description: API access token parameter, tenant_access_token or user_access_token + form: llm + + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: 多维表格 + human_description: + en_US: bitable app token + zh_Hans: 多维表格的唯一标识符 app_token + llm_description: bitable app token + form: llm + + - name: table_id + type: string + required: true + label: + en_US: table_id + zh_Hans: 多维表格的数据表 + human_description: + en_US: bitable table id + zh_Hans: 多维表格数据表的唯一标识符 table_id + llm_description: bitable table id + form: llm + + - name: page_token + type: string + required: false + label: + en_US: page_token + zh_Hans: 分页标记 + human_description: + en_US: Pagination mark. If it is not filled in the first request, it means to traverse from the beginning. + zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历。 + llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + form: llm + + - name: page_size + type: number + required: false + default: 20 + label: + en_US: page_size + zh_Hans: 分页大小 + human_description: + en_US: paging size + zh_Hans: 分页大小,默认值为 20,最大值为 100。 + llm_description: The default value of paging size is 20 and the maximum value is 100. + form: llm + + - name: sort_condition + type: string + required: false + label: + en_US: sort_condition + zh_Hans: 排序条件 + human_description: + en_US: sort condition + zh_Hans: | + 排序条件,格式为:[{"field_name":"多行文本","desc":true}]。 + field_name: 字段名称; + desc: 是否倒序排序; + llm_description: | + Sorting conditions, the format is: [{"field_name":"multi-line text","desc":true}]. + form: llm + + - name: filter_condition + type: string + required: false + label: + en_US: filter_condition + zh_Hans: 筛选条件 + human_description: + en_US: filter condition + zh_Hans: | + 筛选条件,格式为:{"conjunction":"and","conditions":[{"field_name":"字段1","operator":"is","value":["文本内容"]}]}。 + conjunction:条件逻辑连接词; + conditions:筛选条件集合; + field_name:筛选条件的左值,值为字段的名称; + operator:条件运算符; + value:目标值; + llm_description: | + The format of the filter condition is: {"conjunction":"and","conditions":[{"field_name":"Field 1","operator":"is","value":["text content"]}]}. 
+ form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py b/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py new file mode 100644 index 0000000000..4ec9a476bc --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py @@ -0,0 +1,47 @@ +import json +from typing import Any, Union + +import httpx + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class ListBaseTablesTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables" + + access_token = tool_parameters.get("Authorization", "") + if not access_token: + return self.create_text_message("Invalid parameter access_token") + + app_token = tool_parameters.get("app_token", "") + if not app_token: + return self.create_text_message("Invalid parameter app_token") + + page_token = tool_parameters.get("page_token", "") + page_size = tool_parameters.get("page_size", "") + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + } + + params = { + "page_token": page_token, + "page_size": page_size, + } + + try: + res = httpx.get(url.format(app_token=app_token), headers=headers, params=params, timeout=30) + res_json = res.json() + if res.is_success: + return self.create_text_message(text=json.dumps(res_json)) + else: + return self.create_text_message( + f"Failed to list base tables, status code: {res.status_code}, response: {res.text}" + ) + except Exception as e: + return self.create_text_message("Failed to list base tables. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.yaml b/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.yaml new file mode 100644 index 0000000000..9887124a28 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.yaml @@ -0,0 +1,65 @@ +identity: + name: list_base_tables + author: Doug Lea + label: + en_US: List Base Tables + zh_Hans: 根据 app_token 获取多维表格下的所有数据表 +description: + human: + en_US: List base tables + zh_Hans: | + 根据 app_token 获取多维表格下的所有数据表 + llm: A tool for getting all data tables under a multidimensional table based on app_token. +parameters: + - name: Authorization + type: string + required: true + label: + en_US: token + zh_Hans: 凭证 + human_description: + en_US: API access token parameter, tenant_access_token or user_access_token + zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token + llm_description: API access token parameter, tenant_access_token or user_access_token + form: llm + + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: 多维表格 + human_description: + en_US: bitable app token + zh_Hans: 多维表格的唯一标识符 app_token + llm_description: bitable app token + form: llm + + - name: page_token + type: string + required: false + label: + en_US: page_token + zh_Hans: 分页标记 + human_description: + en_US: Pagination mark. If it is not filled in the first request, it means to traverse from the beginning. + zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历。 + llm_description: | + Pagination token. If it is not filled in the first request, it means to start traversal from the beginning. + If there are more items in the pagination query result, a new page_token will be returned at the same time. 
+ The page_token can be used to obtain the query result in the next traversal. + 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + form: llm + + - name: page_size + type: number + required: false + default: 20 + label: + en_US: page_size + zh_Hans: 分页大小 + human_description: + en_US: paging size + zh_Hans: 分页大小,默认值为 20,最大值为 100。 + llm_description: The default value of paging size is 20 and the maximum value is 100. + form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py b/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py new file mode 100644 index 0000000000..fb818f8380 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py @@ -0,0 +1,49 @@ +import json +from typing import Any, Union + +import httpx + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class ReadBaseRecordTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/{record_id}" + + access_token = tool_parameters.get("Authorization", "") + if not access_token: + return self.create_text_message("Invalid parameter access_token") + + app_token = tool_parameters.get("app_token", "") + if not app_token: + return self.create_text_message("Invalid parameter app_token") + + table_id = tool_parameters.get("table_id", "") + if not table_id: + return self.create_text_message("Invalid parameter table_id") + + record_id = tool_parameters.get("record_id", "") + if not record_id: + return self.create_text_message("Invalid parameter record_id") + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + } + + try: + res = httpx.get( + url.format(app_token=app_token, table_id=table_id, record_id=record_id), headers=headers, timeout=30 + ) + res_json = res.json() + if res.is_success: + return self.create_text_message(text=json.dumps(res_json)) + else: + return self.create_text_message( + f"Failed to read base record, status code: {res.status_code}, response: {res.text}" + ) + except Exception as e: + return self.create_text_message("Failed to read base record. 
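Both list tools accept page_token/page_size, and the Feishu API returns a fresh page_token while more items remain. A rough sketch of draining all tables with that cursor, assuming the standard Feishu response envelope (code, data.items, data.has_more, data.page_token):

import httpx

def list_all_tables(access_token: str, app_token: str) -> list[dict]:
    # Follow page_token until has_more is false, collecting every table.
    url = f"https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables"
    headers = {"Authorization": f"Bearer {access_token}"}
    items: list[dict] = []
    page_token = ""
    while True:
        params = {"page_size": 100}
        if page_token:
            params["page_token"] = page_token
        data = httpx.get(url, headers=headers, params=params, timeout=30).json().get("data", {})
        items.extend(data.get("items", []))
        if not data.get("has_more"):
            return items
        page_token = data.get("page_token", "")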
{}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.yaml b/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.yaml new file mode 100644 index 0000000000..400e9a1021 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.yaml @@ -0,0 +1,60 @@ +identity: + name: read_base_record + author: Doug Lea + label: + en_US: Read Base Record + zh_Hans: 根据 record_id 的值检索多维表格数据表的记录 +description: + human: + en_US: Read base record + zh_Hans: | + 根据 record_id 的值检索多维表格数据表的记录 + llm: Retrieve records from a multidimensional table based on the value of record_id +parameters: + - name: Authorization + type: string + required: true + label: + en_US: token + zh_Hans: 凭证 + human_description: + en_US: API access token parameter, tenant_access_token or user_access_token + zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token + llm_description: API access token parameter, tenant_access_token or user_access_token + form: llm + + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: 多维表格 + human_description: + en_US: bitable app token + zh_Hans: 多维表格的唯一标识符 app_token + llm_description: bitable app token + form: llm + + - name: table_id + type: string + required: true + label: + en_US: table_id + zh_Hans: 多维表格的数据表 + human_description: + en_US: bitable table id + zh_Hans: 多维表格数据表的唯一标识符 table_id + llm_description: bitable table id + form: llm + + - name: record_id + type: string + required: true + label: + en_US: record_id + zh_Hans: 单条记录的 id + human_description: + en_US: The id of a single record + zh_Hans: 单条记录的 id + llm_description: The id of a single record + form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py b/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py new file mode 100644 index 0000000000..6d7e33f3ff --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py @@ -0,0 +1,60 @@ +import json +from typing import Any, Union + +import httpx + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class UpdateBaseRecordTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/{record_id}" + + access_token = tool_parameters.get("Authorization", "") + if not access_token: + return self.create_text_message("Invalid parameter access_token") + + app_token = tool_parameters.get("app_token", "") + if not app_token: + return self.create_text_message("Invalid parameter app_token") + + table_id = tool_parameters.get("table_id", "") + if not table_id: + return self.create_text_message("Invalid parameter table_id") + + record_id = tool_parameters.get("record_id", "") + if not record_id: + return self.create_text_message("Invalid parameter record_id") + + fields = tool_parameters.get("fields", "") + if not fields: + return self.create_text_message("Invalid parameter fields") + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + } + + params = {} + payload = {"fields": json.loads(fields)} + + try: + res = httpx.put( + url.format(app_token=app_token, table_id=table_id, record_id=record_id), + headers=headers, + params=params, + json=payload, + timeout=30, + ) + res_json = res.json() + if res.is_success: + 
return self.create_text_message(text=json.dumps(res_json)) + else: + return self.create_text_message( + f"Failed to update base record, status code: {res.status_code}, response: {res.text}" + ) + except Exception as e: + return self.create_text_message("Failed to update base record. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.yaml b/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.yaml new file mode 100644 index 0000000000..788798c4b3 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.yaml @@ -0,0 +1,78 @@ +identity: + name: update_base_record + author: Doug Lea + label: + en_US: Update Base Record + zh_Hans: 更新多维表格数据表中的一条记录 +description: + human: + en_US: Update base record + zh_Hans: | + 更新多维表格数据表中的一条记录,详细请参考:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/app-table-record/update + llm: Update a record in a multidimensional table data table +parameters: + - name: Authorization + type: string + required: true + label: + en_US: token + zh_Hans: 凭证 + human_description: + en_US: API access token parameter, tenant_access_token or user_access_token + zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token + llm_description: API access token parameter, tenant_access_token or user_access_token + form: llm + + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: 多维表格 + human_description: + en_US: bitable app token + zh_Hans: 多维表格的唯一标识符 app_token + llm_description: bitable app token + form: llm + + - name: table_id + type: string + required: true + label: + en_US: table_id + zh_Hans: 多维表格的数据表 + human_description: + en_US: bitable table id + zh_Hans: 多维表格数据表的唯一标识符 table_id + llm_description: bitable table id + form: llm + + - name: record_id + type: string + required: true + label: + en_US: record_id + zh_Hans: 单条记录的 id + human_description: + en_US: The id of a single record + zh_Hans: 单条记录的 id + llm_description: The id of a single record + form: llm + + - name: fields + type: string + required: true + label: + en_US: fields + zh_Hans: 数据表的列字段内容 + human_description: + en_US: The fields of a multidimensional table data table, that is, the columns of the data table. 
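The fields parameter is likewise a JSON string; the tool json.loads it into the PUT body. A minimal sketch of building it; the column names and ids below are illustrative, not taken from this patch:

import json

fields = json.dumps({
    "Multiline text": "updated content",
    "Single select": "Option 1",
    "Multi select": ["Option 1", "Option 2"],
    "Checkbox": True,
    "Phone number": "13026162666",
})

tool_parameters = {
    "Authorization": "t-xxxx",    # tenant_access_token or user_access_token
    "app_token": "bascnXXXX",     # placeholder
    "table_id": "tblXXXX",        # placeholder
    "record_id": "recXXXX",       # placeholder
    "fields": fields,
}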
+ zh_Hans: | + 要更新一行多维表格记录,字段结构拼接如下:{"多行文本":"多行文本内容","单选":"选项1","多选":["选项1","选项2"],"复选框":true,"人员":[{"id":"ou_2910013f1e6456f16a0ce75ede950a0a"}],"群组":[{"id":"oc_cd07f55f14d6f4a4f1b51504e7e97f48"}],"电话号码":"13026162666"} + 当前接口支持的字段类型为:多行文本、单选、条码、多选、日期、人员、附件、复选框、超链接、数字、单向关联、双向关联、电话号码、地理位置。 + 不同类型字段的数据结构请参考数据结构概述:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure + llm_description: | + 要更新一行多维表格记录,字段结构拼接如下:{"多行文本":"多行文本内容","单选":"选项1","多选":["选项1","选项2"],"复选框":true,"人员":[{"id":"ou_2910013f1e6456f16a0ce75ede950a0a"}],"群组":[{"id":"oc_cd07f55f14d6f4a4f1b51504e7e97f48"}],"电话号码":"13026162666"} + 当前接口支持的字段类型为:多行文本、单选、条码、多选、日期、人员、附件、复选框、超链接、数字、单向关联、双向关联、电话号码、地理位置。 + 不同类型字段的数据结构请参考数据结构概述:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure + form: llm diff --git a/api/core/tools/utils/tool_parameter_converter.py b/api/core/tools/utils/tool_parameter_converter.py new file mode 100644 index 0000000000..6f7610651c --- /dev/null +++ b/api/core/tools/utils/tool_parameter_converter.py @@ -0,0 +1,71 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolParameter + + +class ToolParameterConverter: + @staticmethod + def get_parameter_type(parameter_type: str | ToolParameter.ToolParameterType) -> str: + match parameter_type: + case ( + ToolParameter.ToolParameterType.STRING + | ToolParameter.ToolParameterType.SECRET_INPUT + | ToolParameter.ToolParameterType.SELECT + ): + return "string" + + case ToolParameter.ToolParameterType.BOOLEAN: + return "boolean" + + case ToolParameter.ToolParameterType.NUMBER: + return "number" + + case _: + raise ValueError(f"Unsupported parameter type {parameter_type}") + + @staticmethod + def cast_parameter_by_type(value: Any, parameter_type: str) -> Any: + # convert tool parameter config to correct type + try: + match parameter_type: + case ( + ToolParameter.ToolParameterType.STRING + | ToolParameter.ToolParameterType.SECRET_INPUT + | ToolParameter.ToolParameterType.SELECT + ): + if value is None: + return "" + else: + return value if isinstance(value, str) else str(value) + + case ToolParameter.ToolParameterType.BOOLEAN: + if value is None: + return False + elif isinstance(value, str): + # Allowed YAML boolean value strings: https://yaml.org/type/bool.html + # and also '0' for False and '1' for True + match value.lower(): + case "true" | "yes" | "y" | "1": + return True + case "false" | "no" | "n" | "0": + return False + case _: + return bool(value) + else: + return value if isinstance(value, bool) else bool(value) + + case ToolParameter.ToolParameterType.NUMBER: + if isinstance(value, int) | isinstance(value, float): + return value + elif isinstance(value, str) and value != "": + if "." 
in value: + return float(value) + else: + return int(value) + case ToolParameter.ToolParameterType.FILE: + return value + case _: + return str(value) + + except Exception: + raise ValueError(f"The tool parameter value {value} is not in correct type of {parameter_type}.") diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/entities/base_node_data_entities.py new file mode 100644 index 0000000000..2a864dd7a8 --- /dev/null +++ b/api/core/workflow/entities/base_node_data_entities.py @@ -0,0 +1,24 @@ +from abc import ABC +from typing import Optional + +from pydantic import BaseModel + + +class BaseNodeData(ABC, BaseModel): + title: str + desc: Optional[str] = None + + +class BaseIterationNodeData(BaseNodeData): + start_node_id: Optional[str] = None + + +class BaseIterationState(BaseModel): + iteration_node_id: str + index: int + inputs: dict + + class MetaData(BaseModel): + pass + + metadata: MetaData diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py new file mode 100644 index 0000000000..7bfe45a13c --- /dev/null +++ b/api/core/workflow/nodes/base_node.py @@ -0,0 +1,117 @@ +from abc import ABC, abstractmethod +from collections.abc import Generator, Mapping, Sequence +from typing import Any, Optional + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.graph_engine.entities.event import InNodeEvent +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.event import RunCompletedEvent, RunEvent + + +class BaseNode(ABC): + _node_data_cls: type[BaseNodeData] + _node_type: NodeType + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: GraphInitParams, + graph: Graph, + graph_runtime_state: GraphRuntimeState, + previous_node_id: Optional[str] = None, + thread_pool_id: Optional[str] = None, + ) -> None: + self.id = id + self.tenant_id = graph_init_params.tenant_id + self.app_id = graph_init_params.app_id + self.workflow_type = graph_init_params.workflow_type + self.workflow_id = graph_init_params.workflow_id + self.graph_config = graph_init_params.graph_config + self.user_id = graph_init_params.user_id + self.user_from = graph_init_params.user_from + self.invoke_from = graph_init_params.invoke_from + self.workflow_call_depth = graph_init_params.call_depth + self.graph = graph + self.graph_runtime_state = graph_runtime_state + self.previous_node_id = previous_node_id + self.thread_pool_id = thread_pool_id + + node_id = config.get("id") + if not node_id: + raise ValueError("Node ID is required.") + + self.node_id = node_id + self.node_data = self._node_data_cls(**config.get("data", {})) + + @abstractmethod + def _run(self) -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]: + """ + Run node + :return: + """ + raise NotImplementedError + + def run(self) -> Generator[RunEvent | InNodeEvent, None, None]: + """ + Run node entry + :return: + """ + result = self._run() + + if isinstance(result, NodeRunResult): + yield RunCompletedEvent(run_result=result) + else: + yield from result + + @classmethod + def extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], config: dict + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to 
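ToolParameterConverter.cast_parameter_by_type normalizes raw tool-parameter input before invocation: numbers may arrive as strings, and booleans follow the YAML truthy/falsy spellings plus "0"/"1". A quick illustration of the behaviour implemented above:

from core.tools.entities.tool_entities import ToolParameter
from core.tools.utils.tool_parameter_converter import ToolParameterConverter

cast = ToolParameterConverter.cast_parameter_by_type
cast("42", ToolParameter.ToolParameterType.NUMBER)     # -> 42 (int)
cast("3.14", ToolParameter.ToolParameterType.NUMBER)   # -> 3.14 (float)
cast("yes", ToolParameter.ToolParameterType.BOOLEAN)   # -> True
cast("0", ToolParameter.ToolParameterType.BOOLEAN)     # -> False
cast(None, ToolParameter.ToolParameterType.STRING)     # -> "" (None becomes an empty string)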
variable mapping + :param graph_config: graph config + :param config: node config + :return: + """ + node_id = config.get("id") + if not node_id: + raise ValueError("Node ID is required when extracting variable selector to variable mapping.") + + node_data = cls._node_data_cls(**config.get("data", {})) + return cls._extract_variable_selector_to_variable_mapping( + graph_config=graph_config, node_id=node_id, node_data=node_data + ) + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], node_id: str, node_data: BaseNodeData + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ + return {} + + @classmethod + def get_default_config(cls, filters: Optional[dict] = None) -> dict: + """ + Get default config of node. + :param filters: filter by node config parameters. + :return: + """ + return {} + + @property + def node_type(self) -> NodeType: + """ + Get node type + :return: + """ + return self._node_type diff --git a/api/core/workflow/nodes/event.py b/api/core/workflow/nodes/event.py new file mode 100644 index 0000000000..276c13a6d4 --- /dev/null +++ b/api/core/workflow/nodes/event.py @@ -0,0 +1,20 @@ +from pydantic import BaseModel, Field + +from core.workflow.entities.node_entities import NodeRunResult + + +class RunCompletedEvent(BaseModel): + run_result: NodeRunResult = Field(..., description="run result") + + +class RunStreamChunkEvent(BaseModel): + chunk_content: str = Field(..., description="chunk content") + from_variable_selector: list[str] = Field(..., description="from variable selector") + + +class RunRetrieverResourceEvent(BaseModel): + retriever_resources: list[dict] = Field(..., description="retriever resources") + context: str = Field(..., description="context") + + +RunEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py new file mode 100644 index 0000000000..f8ab4e3132 --- /dev/null +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -0,0 +1,343 @@ +import json +from copy import deepcopy +from random import randint +from typing import Any, Optional, Union +from urllib.parse import urlencode + +import httpx + +from configs import dify_config +from core.helper import ssrf_proxy +from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.http_request.entities import ( + HttpRequestNodeAuthorization, + HttpRequestNodeBody, + HttpRequestNodeData, + HttpRequestNodeTimeout, +) +from core.workflow.utils.variable_template_parser import VariableTemplateParser + + +class HttpExecutorResponse: + headers: dict[str, str] + response: httpx.Response + + def __init__(self, response: httpx.Response): + self.response = response + self.headers = dict(response.headers) if isinstance(self.response, httpx.Response) else {} + + @property + def is_file(self) -> bool: + """ + check if response is file + """ + content_type = self.get_content_type() + file_content_types = ["image", "audio", "video"] + + return any(v in content_type for v in file_content_types) + + def get_content_type(self) -> str: + return self.headers.get("content-type", "") + + def extract_file(self) -> tuple[str, bytes]: + """ + extract file from response if content type is file 
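Concrete nodes plug into BaseNode by declaring _node_data_cls and _node_type and implementing _run, which may return a single NodeRunResult or yield streaming events. A minimal illustrative subclass; the node type value and the data field are made up for the example, and a real node would use its own NodeType member:

from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from models.workflow import WorkflowNodeExecutionStatus


class EchoNodeData(BaseNodeData):
    message: str = ""


class EchoNode(BaseNode):
    _node_data_cls = EchoNodeData
    _node_type = NodeType.CODE  # illustrative placeholder for a real node type

    def _run(self) -> NodeRunResult:
        # Echo the configured message back as the node output.
        node_data: EchoNodeData = self.node_data
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs={},
            outputs={"text": node_data.message},
        )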
related + """ + if self.is_file: + return self.get_content_type(), self.body + + return "", b"" + + @property + def content(self) -> str: + if isinstance(self.response, httpx.Response): + return self.response.text + else: + raise ValueError(f"Invalid response type {type(self.response)}") + + @property + def body(self) -> bytes: + if isinstance(self.response, httpx.Response): + return self.response.content + else: + raise ValueError(f"Invalid response type {type(self.response)}") + + @property + def status_code(self) -> int: + if isinstance(self.response, httpx.Response): + return self.response.status_code + else: + raise ValueError(f"Invalid response type {type(self.response)}") + + @property + def size(self) -> int: + return len(self.body) + + @property + def readable_size(self) -> str: + if self.size < 1024: + return f"{self.size} bytes" + elif self.size < 1024 * 1024: + return f"{(self.size / 1024):.2f} KB" + else: + return f"{(self.size / 1024 / 1024):.2f} MB" + + +class HttpExecutor: + server_url: str + method: str + authorization: HttpRequestNodeAuthorization + params: dict[str, Any] + headers: dict[str, Any] + body: Union[None, str] + files: Union[None, dict[str, Any]] + boundary: str + variable_selectors: list[VariableSelector] + timeout: HttpRequestNodeTimeout + + def __init__( + self, + node_data: HttpRequestNodeData, + timeout: HttpRequestNodeTimeout, + variable_pool: Optional[VariablePool] = None, + ): + self.server_url = node_data.url + self.method = node_data.method + self.authorization = node_data.authorization + self.timeout = timeout + self.params = {} + self.headers = {} + self.body = None + self.files = None + + # init template + self.variable_selectors = [] + self._init_template(node_data, variable_pool) + + @staticmethod + def _is_json_body(body: HttpRequestNodeBody): + """ + check if body is json + """ + if body and body.type == "json" and body.data: + try: + json.loads(body.data) + return True + except: + return False + + return False + + @staticmethod + def _to_dict(convert_text: str): + """ + Convert the string like `aa:bb\n cc:dd` to dict `{aa:bb, cc:dd}` + """ + kv_paris = convert_text.split("\n") + result = {} + for kv in kv_paris: + if not kv.strip(): + continue + + kv = kv.split(":", maxsplit=1) + if len(kv) == 1: + k, v = kv[0], "" + else: + k, v = kv + result[k.strip()] = v + return result + + def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None): + # extract all template in url + self.server_url, server_url_variable_selectors = self._format_template(node_data.url, variable_pool) + + # extract all template in params + params, params_variable_selectors = self._format_template(node_data.params, variable_pool) + self.params = self._to_dict(params) + + # extract all template in headers + headers, headers_variable_selectors = self._format_template(node_data.headers, variable_pool) + self.headers = self._to_dict(headers) + + # extract all template in body + body_data_variable_selectors = [] + if node_data.body: + # check if it's a valid JSON + is_valid_json = self._is_json_body(node_data.body) + + body_data = node_data.body.data or "" + if body_data: + body_data, body_data_variable_selectors = self._format_template(body_data, variable_pool, is_valid_json) + + content_type_is_set = any(key.lower() == "content-type" for key in self.headers) + if node_data.body.type == "json" and not content_type_is_set: + self.headers["Content-Type"] = "application/json" + elif node_data.body.type == "x-www-form-urlencoded" and not 
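Query params, headers, and form bodies are authored as newline-separated key:value text and parsed by the private _to_dict helper; only the first colon splits, so values may themselves contain colons, and a key without a value maps to an empty string. A quick check of that parsing:

from core.workflow.nodes.http_request.http_executor import HttpExecutor

text = "Content-Type:application/json\nX-Request-Id:abc:123\nX-Empty-Flag"
HttpExecutor._to_dict(text)
# -> {"Content-Type": "application/json", "X-Request-Id": "abc:123", "X-Empty-Flag": ""}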
content_type_is_set: + self.headers["Content-Type"] = "application/x-www-form-urlencoded" + + if node_data.body.type in {"form-data", "x-www-form-urlencoded"}: + body = self._to_dict(body_data) + + if node_data.body.type == "form-data": + self.files = {k: ("", v) for k, v in body.items()} + random_str = lambda n: "".join([chr(randint(97, 122)) for _ in range(n)]) + self.boundary = f"----WebKitFormBoundary{random_str(16)}" + + self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}" + else: + self.body = urlencode(body) + elif node_data.body.type in {"json", "raw-text"}: + self.body = body_data + elif node_data.body.type == "none": + self.body = "" + + self.variable_selectors = ( + server_url_variable_selectors + + params_variable_selectors + + headers_variable_selectors + + body_data_variable_selectors + ) + + def _assembling_headers(self) -> dict[str, Any]: + authorization = deepcopy(self.authorization) + headers = deepcopy(self.headers) or {} + if self.authorization.type == "api-key": + if self.authorization.config is None: + raise ValueError("self.authorization config is required") + if authorization.config is None: + raise ValueError("authorization config is required") + + if self.authorization.config.api_key is None: + raise ValueError("api_key is required") + + if not authorization.config.header: + authorization.config.header = "Authorization" + + if self.authorization.config.type == "bearer": + headers[authorization.config.header] = f"Bearer {authorization.config.api_key}" + elif self.authorization.config.type == "basic": + headers[authorization.config.header] = f"Basic {authorization.config.api_key}" + elif self.authorization.config.type == "custom": + headers[authorization.config.header] = authorization.config.api_key + + return headers + + def _validate_and_parse_response(self, response: httpx.Response) -> HttpExecutorResponse: + """ + validate the response + """ + if isinstance(response, httpx.Response): + executor_response = HttpExecutorResponse(response) + else: + raise ValueError(f"Invalid response type {type(response)}") + + threshold_size = ( + dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE + if executor_response.is_file + else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE + ) + if executor_response.size > threshold_size: + raise ValueError( + f'{"File" if executor_response.is_file else "Text"} size is too large,' + f' max size is {threshold_size / 1024 / 1024:.2f} MB,' + f' but current size is {executor_response.readable_size}.' 
+ ) + + return executor_response + + def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: + """ + do http request depending on api bundle + """ + kwargs = { + "url": self.server_url, + "headers": headers, + "params": self.params, + "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), + "follow_redirects": True, + } + + if self.method in {"get", "head", "post", "put", "delete", "patch"}: + response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs) + else: + raise ValueError(f"Invalid http method {self.method}") + return response + + def invoke(self) -> HttpExecutorResponse: + """ + invoke http request + """ + # assemble headers + headers = self._assembling_headers() + + # do http request + response = self._do_http_request(headers) + + # validate response + return self._validate_and_parse_response(response) + + def to_raw_request(self) -> str: + """ + convert to raw request + """ + server_url = self.server_url + if self.params: + server_url += f"?{urlencode(self.params)}" + + raw_request = f"{self.method.upper()} {server_url} HTTP/1.1\n" + + headers = self._assembling_headers() + for k, v in headers.items(): + # get authorization header + if self.authorization.type == "api-key": + authorization_header = "Authorization" + if self.authorization.config and self.authorization.config.header: + authorization_header = self.authorization.config.header + + if k.lower() == authorization_header.lower(): + raw_request += f'{k}: {"*" * len(v)}\n' + continue + + raw_request += f"{k}: {v}\n" + + raw_request += "\n" + + # if files, use multipart/form-data with boundary + if self.files: + boundary = self.boundary + raw_request += f"--{boundary}" + for k, v in self.files.items(): + raw_request += f'\nContent-Disposition: form-data; name="{k}"\n\n' + raw_request += f"{v[1]}\n" + raw_request += f"--{boundary}" + raw_request += "--" + else: + raw_request += self.body or "" + + return raw_request + + def _format_template( + self, template: str, variable_pool: Optional[VariablePool], escape_quotes: bool = False + ) -> tuple[str, list[VariableSelector]]: + """ + format template + """ + variable_template_parser = VariableTemplateParser(template=template) + variable_selectors = variable_template_parser.extract_variable_selectors() + + if variable_pool: + variable_value_mapping = {} + for variable_selector in variable_selectors: + variable = variable_pool.get_any(variable_selector.value_selector) + if variable is None: + raise ValueError(f"Variable {variable_selector.variable} not found") + if escape_quotes and isinstance(variable, str): + value = variable.replace('"', '\\"').replace("\n", "\\n") + else: + value = variable + variable_value_mapping[variable_selector.variable] = value + + return variable_template_parser.format(variable_value_mapping), variable_selectors + else: + return template, variable_selectors diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py new file mode 100644 index 0000000000..cd40819126 --- /dev/null +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -0,0 +1,165 @@ +import logging +from collections.abc import Mapping, Sequence +from mimetypes import guess_extension +from os import path +from typing import Any, cast + +from configs import dify_config +from core.app.segments import parser +from core.file.file_obj import FileTransferMethod, FileType, FileVar +from core.tools.tool_file_manager import ToolFileManager +from 
core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.http_request.entities import ( + HttpRequestNodeData, + HttpRequestNodeTimeout, +) +from core.workflow.nodes.http_request.http_executor import HttpExecutor, HttpExecutorResponse +from models.workflow import WorkflowNodeExecutionStatus + +HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( + connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, + read=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, + write=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, +) + + +class HttpRequestNode(BaseNode): + _node_data_cls = HttpRequestNodeData + _node_type = NodeType.HTTP_REQUEST + + @classmethod + def get_default_config(cls, filters: dict | None = None) -> dict: + return { + "type": "http-request", + "config": { + "method": "get", + "authorization": { + "type": "no-auth", + }, + "body": {"type": "none"}, + "timeout": { + **HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(), + "max_connect_timeout": dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, + "max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, + "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, + }, + }, + } + + def _run(self) -> NodeRunResult: + node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data) + # TODO: Switch to use segment directly + if node_data.authorization.config and node_data.authorization.config.api_key: + node_data.authorization.config.api_key = parser.convert_template( + template=node_data.authorization.config.api_key, variable_pool=self.graph_runtime_state.variable_pool + ).text + + # init http executor + http_executor = None + try: + http_executor = HttpExecutor( + node_data=node_data, + timeout=self._get_request_timeout(node_data), + variable_pool=self.graph_runtime_state.variable_pool, + ) + + # invoke http executor + response = http_executor.invoke() + except Exception as e: + process_data = {} + if http_executor: + process_data = { + "request": http_executor.to_raw_request(), + } + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + process_data=process_data, + ) + + files = self.extract_files(http_executor.server_url, response) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={ + "status_code": response.status_code, + "body": response.content if not files else "", + "headers": response.headers, + "files": files, + }, + process_data={ + "request": http_executor.to_raw_request(), + }, + ) + + @staticmethod + def _get_request_timeout(node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout: + timeout = node_data.timeout + if timeout is None: + return HTTP_REQUEST_DEFAULT_TIMEOUT + + timeout.connect = timeout.connect or HTTP_REQUEST_DEFAULT_TIMEOUT.connect + timeout.read = timeout.read or HTTP_REQUEST_DEFAULT_TIMEOUT.read + timeout.write = timeout.write or HTTP_REQUEST_DEFAULT_TIMEOUT.write + return timeout + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], node_id: str, node_data: HttpRequestNodeData + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ + try: + http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT) + + variable_selectors = http_executor.variable_selectors + + variable_mapping = {} + for variable_selector in variable_selectors: + 
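The per-node timeout is optional; _get_request_timeout falls back field by field to the configured HTTP_REQUEST_MAX_* defaults. The numbers below are stand-ins, and only the fallback rule is taken from the code above:

# Mirrors _get_request_timeout: any unset field falls back to the default.
DEFAULTS = {"connect": 10, "read": 60, "write": 20}           # stand-ins for HTTP_REQUEST_MAX_*
node_timeout = {"connect": None, "read": 300, "write": None}  # node only overrides read

effective = {key: node_timeout.get(key) or DEFAULTS[key] for key in DEFAULTS}
# -> {"connect": 10, "read": 300, "write": 20}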
variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector + + return variable_mapping + except Exception as e: + logging.exception(f"Failed to extract variable selector to variable mapping: {e}") + return {} + + def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVar]: + """ + Extract files from response + """ + files = [] + mimetype, file_binary = response.extract_file() + + if mimetype: + # extract filename from url + filename = path.basename(url) + # extract extension if possible + extension = guess_extension(mimetype) or ".bin" + + tool_file = ToolFileManager.create_file_by_raw( + user_id=self.user_id, + tenant_id=self.tenant_id, + conversation_id=None, + file_binary=file_binary, + mimetype=mimetype, + ) + + files.append( + FileVar( + tenant_id=self.tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id=tool_file.id, + filename=filename, + extension=extension, + mime_type=mimetype, + ) + ) + + return files diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py new file mode 100644 index 0000000000..3d336b0b0b --- /dev/null +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -0,0 +1,774 @@ +import json +from collections.abc import Generator, Mapping, Sequence +from copy import deepcopy +from typing import TYPE_CHECKING, Any, Optional, cast + +from pydantic import BaseModel + +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.entities.model_entities import ModelStatus +from core.entities.provider_entities import QuotaUnit +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from core.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, +) +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.utils.encoders import jsonable_encoder +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig +from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.event import InNodeEvent +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent +from core.workflow.nodes.llm.entities import ( + LLMNodeChatModelMessage, + LLMNodeCompletionModelPromptTemplate, + LLMNodeData, + ModelConfig, +) +from core.workflow.utils.variable_template_parser import VariableTemplateParser +from extensions.ext_database import db +from models.model import Conversation +from models.provider import Provider, ProviderType +from models.workflow import WorkflowNodeExecutionStatus + +if TYPE_CHECKING: + from core.file.file_obj import FileVar + + +class ModelInvokeCompleted(BaseModel): + """ + Model invoke completed + """ + + 
text: str + usage: LLMUsage + finish_reason: Optional[str] = None + + +class LLMNode(BaseNode): + _node_data_cls = LLMNodeData + _node_type = NodeType.LLM + + def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]: + """ + Run node + :return: + """ + node_data = cast(LLMNodeData, deepcopy(self.node_data)) + variable_pool = self.graph_runtime_state.variable_pool + + node_inputs = None + process_data = None + + try: + # init messages template + node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template) + + # fetch variables and fetch values from variable pool + inputs = self._fetch_inputs(node_data, variable_pool) + + # fetch jinja2 inputs + jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool) + + # merge inputs + inputs.update(jinja_inputs) + + node_inputs = {} + + # fetch files + files = self._fetch_files(node_data, variable_pool) + + if files: + node_inputs["#files#"] = [file.to_dict() for file in files] + + # fetch context value + generator = self._fetch_context(node_data, variable_pool) + context = None + for event in generator: + if isinstance(event, RunRetrieverResourceEvent): + context = event.context + yield event + + if context: + node_inputs["#context#"] = context # type: ignore + + # fetch model config + model_instance, model_config = self._fetch_model_config(node_data.model) + + # fetch memory + memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) + + # fetch prompt messages + prompt_messages, stop = self._fetch_prompt_messages( + node_data=node_data, + query=variable_pool.get_any(["sys", SystemVariableKey.QUERY.value]) if node_data.memory else None, + query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None, + inputs=inputs, + files=files, + context=context, + memory=memory, + model_config=model_config, + ) + + process_data = { + "model_mode": model_config.mode, + "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, prompt_messages=prompt_messages + ), + "model_provider": model_config.provider, + "model_name": model_config.model, + } + + # handle invoke result + generator = self._invoke_llm( + node_data_model=node_data.model, + model_instance=model_instance, + prompt_messages=prompt_messages, + stop=stop, + ) + + result_text = "" + usage = LLMUsage.empty_usage() + finish_reason = None + for event in generator: + if isinstance(event, RunStreamChunkEvent): + yield event + elif isinstance(event, ModelInvokeCompleted): + result_text = event.text + usage = event.usage + finish_reason = event.finish_reason + break + except Exception as e: + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + inputs=node_inputs, + process_data=process_data, + ) + ) + return + + outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=node_inputs, + process_data=process_data, + outputs=outputs, + metadata={ + NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, + NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, + NodeRunMetadataKey.CURRENCY: usage.currency, + }, + llm_usage=usage, + ) + ) + + def _invoke_llm( + self, + node_data_model: ModelConfig, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + stop: Optional[list[str]] = None, + ) -> Generator[RunEvent | ModelInvokeCompleted, None, None]: + """ + Invoke 
large language model + :param node_data_model: node data model + :param model_instance: model instance + :param prompt_messages: prompt messages + :param stop: stop + :return: + """ + db.session.close() + + invoke_result = model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=node_data_model.completion_params, + stop=stop, + stream=True, + user=self.user_id, + ) + + # handle invoke result + generator = self._handle_invoke_result(invoke_result=invoke_result) + + usage = LLMUsage.empty_usage() + for event in generator: + yield event + if isinstance(event, ModelInvokeCompleted): + usage = event.usage + + # deduct quota + self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) + + def _handle_invoke_result( + self, invoke_result: LLMResult | Generator + ) -> Generator[RunEvent | ModelInvokeCompleted, None, None]: + """ + Handle invoke result + :param invoke_result: invoke result + :return: + """ + if isinstance(invoke_result, LLMResult): + return + + model = None + prompt_messages: list[PromptMessage] = [] + full_text = "" + usage = None + finish_reason = None + for result in invoke_result: + text = result.delta.message.content + full_text += text + + yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"]) + + if not model: + model = result.model + + if not prompt_messages: + prompt_messages = result.prompt_messages + + if not usage and result.delta.usage: + usage = result.delta.usage + + if not finish_reason and result.delta.finish_reason: + finish_reason = result.delta.finish_reason + + if not usage: + usage = LLMUsage.empty_usage() + + yield ModelInvokeCompleted(text=full_text, usage=usage, finish_reason=finish_reason) + + def _transform_chat_messages( + self, messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate + ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: + """ + Transform chat messages + + :param messages: chat messages + :return: + """ + + if isinstance(messages, LLMNodeCompletionModelPromptTemplate): + if messages.edition_type == "jinja2" and messages.jinja2_text: + messages.text = messages.jinja2_text + + return messages + + for message in messages: + if message.edition_type == "jinja2" and message.jinja2_text: + message.text = message.jinja2_text + + return messages + + def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]: + """ + Fetch jinja inputs + :param node_data: node data + :param variable_pool: variable pool + :return: + """ + variables = {} + + if not node_data.prompt_config: + return variables + + for variable_selector in node_data.prompt_config.jinja2_variables or []: + variable = variable_selector.variable + value = variable_pool.get_any(variable_selector.value_selector) + + def parse_dict(d: dict) -> str: + """ + Parse dict into string + """ + # check if it's a context structure + if "metadata" in d and "_source" in d["metadata"] and "content" in d: + return d["content"] + + # else, parse the dict + try: + return json.dumps(d, ensure_ascii=False) + except Exception: + return str(d) + + if isinstance(value, str): + value = value + elif isinstance(value, list): + result = "" + for item in value: + if isinstance(item, dict): + result += parse_dict(item) + elif isinstance(item, str): + result += item + elif isinstance(item, int | float): + result += str(item) + else: + result += str(item) + result += "\n" + value = result.strip() + elif isinstance(value, dict): + value = 
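Callers consume an LLM node as a generator: text chunks arrive as RunStreamChunkEvent and the final status, outputs, and usage arrive as RunCompletedEvent. A rough consumer sketch; construction of the node itself is omitted, since it needs a full graph runtime state:

from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent

def drain_llm_node(node) -> str:
    # Print streamed chunks as they arrive and return the final text output.
    final_text = ""
    for event in node.run():
        if isinstance(event, RunStreamChunkEvent):
            print(event.chunk_content, end="", flush=True)
        elif isinstance(event, RunCompletedEvent):
            outputs = event.run_result.outputs or {}
            final_text = outputs.get("text", "")
    return final_text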
parse_dict(value) + elif isinstance(value, int | float): + value = str(value) + else: + value = str(value) + + variables[variable] = value + + return variables + + def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]: + """ + Fetch inputs + :param node_data: node data + :param variable_pool: variable pool + :return: + """ + inputs = {} + prompt_template = node_data.prompt_template + + variable_selectors = [] + if isinstance(prompt_template, list): + for prompt in prompt_template: + variable_template_parser = VariableTemplateParser(template=prompt.text) + variable_selectors.extend(variable_template_parser.extract_variable_selectors()) + elif isinstance(prompt_template, CompletionModelPromptTemplate): + variable_template_parser = VariableTemplateParser(template=prompt_template.text) + variable_selectors = variable_template_parser.extract_variable_selectors() + + for variable_selector in variable_selectors: + variable_value = variable_pool.get_any(variable_selector.value_selector) + if variable_value is None: + raise ValueError(f"Variable {variable_selector.variable} not found") + + inputs[variable_selector.variable] = variable_value + + memory = node_data.memory + if memory and memory.query_prompt_template: + query_variable_selectors = VariableTemplateParser( + template=memory.query_prompt_template + ).extract_variable_selectors() + for variable_selector in query_variable_selectors: + variable_value = variable_pool.get_any(variable_selector.value_selector) + if variable_value is None: + raise ValueError(f"Variable {variable_selector.variable} not found") + + inputs[variable_selector.variable] = variable_value + + return inputs + + def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]: + """ + Fetch files + :param node_data: node data + :param variable_pool: variable pool + :return: + """ + if not node_data.vision.enabled: + return [] + + files = variable_pool.get_any(["sys", SystemVariableKey.FILES.value]) + if not files: + return [] + + return files + + def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]: + """ + Fetch context + :param node_data: node data + :param variable_pool: variable pool + :return: + """ + if not node_data.context.enabled: + return + + if not node_data.context.variable_selector: + return + + context_value = variable_pool.get_any(node_data.context.variable_selector) + if context_value: + if isinstance(context_value, str): + yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value) + elif isinstance(context_value, list): + context_str = "" + original_retriever_resource = [] + for item in context_value: + if isinstance(item, str): + context_str += item + "\n" + else: + if "content" not in item: + raise ValueError(f"Invalid context structure: {item}") + + context_str += item["content"] + "\n" + + retriever_resource = self._convert_to_original_retriever_resource(item) + if retriever_resource: + original_retriever_resource.append(retriever_resource) + + yield RunRetrieverResourceEvent( + retriever_resources=original_retriever_resource, context=context_str.strip() + ) + + def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]: + """ + Convert to original retriever resource, temp. 
+ :param context_dict: context dict + :return: + """ + if ( + "metadata" in context_dict + and "_source" in context_dict["metadata"] + and context_dict["metadata"]["_source"] == "knowledge" + ): + metadata = context_dict.get("metadata", {}) + + source = { + "position": metadata.get("position"), + "dataset_id": metadata.get("dataset_id"), + "dataset_name": metadata.get("dataset_name"), + "document_id": metadata.get("document_id"), + "document_name": metadata.get("document_name"), + "data_source_type": metadata.get("document_data_source_type"), + "segment_id": metadata.get("segment_id"), + "retriever_from": metadata.get("retriever_from"), + "score": metadata.get("score"), + "hit_count": metadata.get("segment_hit_count"), + "word_count": metadata.get("segment_word_count"), + "segment_position": metadata.get("segment_position"), + "index_node_hash": metadata.get("segment_index_node_hash"), + "content": context_dict.get("content"), + } + + return source + + return None + + def _fetch_model_config( + self, node_data_model: ModelConfig + ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: + """ + Fetch model config + :param node_data_model: node data model + :return: + """ + model_name = node_data_model.name + provider_name = node_data_model.provider + + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name + ) + + provider_model_bundle = model_instance.provider_model_bundle + model_type_instance = model_instance.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + + model_credentials = model_instance.credentials + + # check model + provider_model = provider_model_bundle.configuration.get_provider_model( + model=model_name, model_type=ModelType.LLM + ) + + if provider_model is None: + raise ValueError(f"Model {model_name} not exist.") + + if provider_model.status == ModelStatus.NO_CONFIGURE: + raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + elif provider_model.status == ModelStatus.NO_PERMISSION: + raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") + elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: + raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") + + # model config + completion_params = node_data_model.completion_params + stop = [] + if "stop" in completion_params: + stop = completion_params["stop"] + del completion_params["stop"] + + # get model mode + model_mode = node_data_model.mode + if not model_mode: + raise ValueError("LLM mode is required.") + + model_schema = model_type_instance.get_model_schema(model_name, model_credentials) + + if not model_schema: + raise ValueError(f"Model {model_name} not exist.") + + return model_instance, ModelConfigWithCredentialsEntity( + provider=provider_name, + model=model_name, + model_schema=model_schema, + mode=model_mode, + provider_model_bundle=provider_model_bundle, + credentials=model_credentials, + parameters=completion_params, + stop=stop, + ) + + def _fetch_memory( + self, node_data_memory: Optional[MemoryConfig], variable_pool: VariablePool, model_instance: ModelInstance + ) -> Optional[TokenBufferMemory]: + """ + Fetch memory + :param node_data_memory: node data memory + :param variable_pool: variable pool + :return: + """ + if not node_data_memory: + return None + + # get conversation id + conversation_id = variable_pool.get_any(["sys", 
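Only context items whose metadata carries _source == "knowledge" are converted into citation sources; plain strings and other dicts contribute text but no source. An illustrative item of the expected shape, with placeholder ids and names:

context_item = {
    "content": "Dify is an LLM application development platform.",
    "metadata": {
        "_source": "knowledge",
        "dataset_id": "d-1", "dataset_name": "Docs",
        "document_id": "doc-1", "document_name": "intro.md",
        "segment_id": "seg-1", "position": 1, "score": 0.82,
    },
}
# _convert_to_original_retriever_resource(context_item) returns a source dict whose
# "content" mirrors context_item["content"]; anything without that metadata yields None.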
SystemVariableKey.CONVERSATION_ID.value]) + if conversation_id is None: + return None + + # get conversation + conversation = ( + db.session.query(Conversation) + .filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id) + .first() + ) + + if not conversation: + return None + + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) + + return memory + + def _fetch_prompt_messages( + self, + node_data: LLMNodeData, + query: Optional[str], + query_prompt_template: Optional[str], + inputs: dict[str, str], + files: list["FileVar"], + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: + """ + Fetch prompt messages + :param node_data: node data + :param query: query + :param query_prompt_template: query prompt template + :param inputs: inputs + :param files: files + :param context: context + :param memory: memory + :param model_config: model config + :return: + """ + prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) + prompt_messages = prompt_transform.get_prompt( + prompt_template=node_data.prompt_template, + inputs=inputs, + query=query or "", + files=files, + context=context, + memory_config=node_data.memory, + memory=memory, + model_config=model_config, + query_prompt_template=query_prompt_template, + ) + stop = model_config.stop + + vision_enabled = node_data.vision.enabled + vision_detail = node_data.vision.configs.detail if node_data.vision.configs else None + filtered_prompt_messages = [] + for prompt_message in prompt_messages: + if prompt_message.is_empty(): + continue + + if not isinstance(prompt_message.content, str): + prompt_message_content = [] + for content_item in prompt_message.content: + if ( + vision_enabled + and content_item.type == PromptMessageContentType.IMAGE + and isinstance(content_item, ImagePromptMessageContent) + ): + # Override vision config if LLM node has vision config + if vision_detail: + content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail) + prompt_message_content.append(content_item) + elif content_item.type == PromptMessageContentType.TEXT: + prompt_message_content.append(content_item) + + if len(prompt_message_content) > 1: + prompt_message.content = prompt_message_content + elif ( + len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT + ): + prompt_message.content = prompt_message_content[0].data + + filtered_prompt_messages.append(prompt_message) + + if not filtered_prompt_messages: + raise ValueError( + "No prompt found in the LLM configuration. " + "Please ensure a prompt is properly configured before proceeding." 
+ ) + + return filtered_prompt_messages, stop + + @classmethod + def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: + """ + Deduct LLM quota + :param tenant_id: tenant id + :param model_instance: model instance + :param usage: usage + :return: + """ + provider_model_bundle = model_instance.provider_model_bundle + provider_configuration = provider_model_bundle.configuration + + if provider_configuration.using_provider_type != ProviderType.SYSTEM: + return + + system_configuration = provider_configuration.system_configuration + + quota_unit = None + for quota_configuration in system_configuration.quota_configurations: + if quota_configuration.quota_type == system_configuration.current_quota_type: + quota_unit = quota_configuration.quota_unit + + if quota_configuration.quota_limit == -1: + return + + break + + used_quota = None + if quota_unit: + if quota_unit == QuotaUnit.TOKENS: + used_quota = usage.total_tokens + elif quota_unit == QuotaUnit.CREDITS: + used_quota = 1 + + if "gpt-4" in model_instance.model: + used_quota = 20 + else: + used_quota = 1 + + if used_quota is not None: + db.session.query(Provider).filter( + Provider.tenant_id == tenant_id, + Provider.provider_name == model_instance.provider, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == system_configuration.current_quota_type.value, + Provider.quota_limit > Provider.quota_used, + ).update({"quota_used": Provider.quota_used + used_quota}) + db.session.commit() + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], node_id: str, node_data: LLMNodeData + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ + prompt_template = node_data.prompt_template + + variable_selectors = [] + if isinstance(prompt_template, list): + for prompt in prompt_template: + if prompt.edition_type != "jinja2": + variable_template_parser = VariableTemplateParser(template=prompt.text) + variable_selectors.extend(variable_template_parser.extract_variable_selectors()) + else: + if prompt_template.edition_type != "jinja2": + variable_template_parser = VariableTemplateParser(template=prompt_template.text) + variable_selectors = variable_template_parser.extract_variable_selectors() + + variable_mapping = {} + for variable_selector in variable_selectors: + variable_mapping[variable_selector.variable] = variable_selector.value_selector + + memory = node_data.memory + if memory and memory.query_prompt_template: + query_variable_selectors = VariableTemplateParser( + template=memory.query_prompt_template + ).extract_variable_selectors() + for variable_selector in query_variable_selectors: + variable_mapping[variable_selector.variable] = variable_selector.value_selector + + if node_data.context.enabled: + variable_mapping["#context#"] = node_data.context.variable_selector + + if node_data.vision.enabled: + variable_mapping["#files#"] = ["sys", SystemVariableKey.FILES.value] + + if node_data.memory: + variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value] + + if node_data.prompt_config: + enable_jinja = False + + if isinstance(prompt_template, list): + for prompt in prompt_template: + if prompt.edition_type == "jinja2": + enable_jinja = True + break + else: + if prompt_template.edition_type == "jinja2": + enable_jinja = True + + if enable_jinja: + for variable_selector in 
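For hosted (system) providers the deduction depends on the quota unit: token-based quotas charge usage.total_tokens, credit-based quotas charge one credit per call, bumped to 20 when the model name contains "gpt-4", and any other unit deducts nothing. An illustrative helper mirroring that rule; it is not part of the API, and the unit strings are stand-ins for QuotaUnit members:

from typing import Optional

def used_quota_for_call(model: str, total_tokens: int, quota_unit: str) -> Optional[int]:
    # Mirror of the charging rule in LLMNode.deduct_llm_quota.
    if quota_unit == "tokens":
        return total_tokens                   # charge the full token usage
    if quota_unit == "credits":
        return 20 if "gpt-4" in model else 1  # flat per-call credit
    return None                               # unknown unit: nothing is deducted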
node_data.prompt_config.jinja2_variables or []: + variable_mapping[variable_selector.variable] = variable_selector.value_selector + + variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} + + return variable_mapping + + @classmethod + def get_default_config(cls, filters: Optional[dict] = None) -> dict: + """ + Get default config of node. + :param filters: filter by node config parameters. + :return: + """ + return { + "type": "llm", + "config": { + "prompt_templates": { + "chat_model": { + "prompts": [ + {"role": "system", "text": "You are a helpful AI assistant.", "edition_type": "basic"} + ] + }, + "completion_model": { + "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, + "prompt": { + "text": "Here is the chat histories between human and assistant, inside " + " XML tags.\n\n\n{{" + "#histories#}}\n\n\n\nHuman: {{#sys.query#}}\n\nAssistant:", + "edition_type": "basic", + }, + "stop": ["Human:"], + }, + } + }, + } diff --git a/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py b/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py new file mode 100644 index 0000000000..279a6cdbc3 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py @@ -0,0 +1,56 @@ +import pytest + +from core.tools.entities.tool_entities import ToolParameter +from core.tools.utils.tool_parameter_converter import ToolParameterConverter + + +def test_get_parameter_type(): + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == "string" + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == "string" + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == "boolean" + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == "number" + with pytest.raises(ValueError): + ToolParameterConverter.get_parameter_type("unsupported_type") + + +def test_cast_parameter_by_type(): + # string + assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.STRING) == "test" + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == "1" + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == "1.0" + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == "" + + # secret input + assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SECRET_INPUT) == "test" + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == "1" + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == "1.0" + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == "" + + # select + assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SELECT) == "test" + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == "1" + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == "1.0" + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == "" + + # boolean + true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"] + for value in true_values: + assert 
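To make the output of _extract_variable_selector_to_variable_mapping (earlier in this hunk) concrete: for a hypothetical chat-model node "llm_1" whose prompt references {{#start.user_input#}}, with context, vision and memory all enabled, the returned mapping would look roughly like the sketch below. The selectors on the right-hand side are invented example values, not output captured from the code.

    # Illustrative expected shape only; every key is prefixed with the node id.
    expected_mapping = {
        "llm_1.user_input": ["start", "user_input"],   # parsed from {{#start.user_input#}}
        "llm_1.#context#": ["knowledge", "result"],    # node_data.context.variable_selector
        "llm_1.#files#": ["sys", "files"],             # added because vision is enabled
        "llm_1.#sys.query#": ["sys", "query"],         # added because memory is configured
    }

Prompts whose edition_type is "jinja2" skip the {{#...#}} template parser and instead contribute the selectors listed in prompt_config.jinja2_variables.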
ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is True + + false_values = [False, "False", "false", "0", "NO", "No", "no", "n", None, ""] + for value in false_values: + assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is False + + # number + assert ToolParameterConverter.cast_parameter_by_type("1", ToolParameter.ToolParameterType.NUMBER) == 1 + assert ToolParameterConverter.cast_parameter_by_type("1.0", ToolParameter.ToolParameterType.NUMBER) == 1.0 + assert ToolParameterConverter.cast_parameter_by_type("-1.0", ToolParameter.ToolParameterType.NUMBER) == -1.0 + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.NUMBER) == 1 + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.NUMBER) == 1.0 + assert ToolParameterConverter.cast_parameter_by_type(-1.0, ToolParameter.ToolParameterType.NUMBER) == -1.0 + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None + + # unknown + assert ToolParameterConverter.cast_parameter_by_type("1", "unknown_type") == "1" + assert ToolParameterConverter.cast_parameter_by_type(1, "unknown_type") == "1" + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None diff --git a/web/app/components/app/configuration/config-var/select-type-item/style.module.css b/web/app/components/app/configuration/config-var/select-type-item/style.module.css new file mode 100644 index 0000000000..8ff716d58b --- /dev/null +++ b/web/app/components/app/configuration/config-var/select-type-item/style.module.css @@ -0,0 +1,40 @@ +.item { + display: flex; + flex-direction: column; + justify-content: center; + align-items: center; + height: 58px; + width: 98px; + border-radius: 8px; + border: 1px solid #EAECF0; + box-shadow: 0px 1px 2px rgba(16, 24, 40, 0.05); + background-color: #fff; + cursor: pointer; +} + +.item:not(.selected):hover { + border-color: #B2CCFF; + background-color: #F5F8FF; + box-shadow: 0px 4px 8px -2px rgba(16, 24, 40, 0.1), 0px 2px 4px -2px rgba(16, 24, 40, 0.06); +} + +.item.selected { + color: #155EEF; + border-color: #528BFF; + background-color: #F5F8FF; + box-shadow: 0px 1px 3px rgba(16, 24, 40, 0.1), 0px 1px 2px rgba(16, 24, 40, 0.06); +} + +.text { + font-size: 13px; + color: #667085; + font-weight: 500; +} + +.item.selected .text { + color: #155EEF; +} + +.item:not(.selected):hover { + color: #344054; +} \ No newline at end of file diff --git a/web/app/components/app/configuration/config-vision/radio-group/index.tsx b/web/app/components/app/configuration/config-vision/radio-group/index.tsx new file mode 100644 index 0000000000..a1cfb06e6a --- /dev/null +++ b/web/app/components/app/configuration/config-vision/radio-group/index.tsx @@ -0,0 +1,40 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import s from './style.module.css' +import cn from '@/utils/classnames' + +type OPTION = { + label: string + value: any +} + +type Props = { + className?: string + options: OPTION[] + value: any + onChange: (value: any) => void +} + +const RadioGroup: FC = ({ + className = '', + options, + value, + onChange, +}) => { + return ( +
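The boolean and number expectations encoded in test_tool_parameter_converter.py above can be summarised by the rough reimplementation below. It is written only to satisfy the assertions in these tests and is not the converter's actual source.

    FALSEY = {"false", "0", "no", "n", ""}

    def cast_bool(value) -> bool:
        # None and false-like strings become False; any other non-empty value,
        # including arbitrary strings such as "something", becomes True.
        if value is None:
            return False
        if isinstance(value, bool):
            return value
        return str(value).strip().lower() not in FALSEY

    def cast_number(value):
        # Numbers pass through, strings are parsed, None stays None.
        if value is None or isinstance(value, (int, float)):
            return value
        return float(value) if "." in value else int(value)

    assert cast_bool("YES") is True and cast_bool("n") is False
    assert cast_number("-1.0") == -1.0 and cast_number("1") == 1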
+ {options.map(item => ( +
onChange(item.value)} + > +
+
{item.label}
+
+ ))} +
+ ) +} +export default React.memo(RadioGroup) diff --git a/web/app/components/app/configuration/config-vision/radio-group/style.module.css b/web/app/components/app/configuration/config-vision/radio-group/style.module.css new file mode 100644 index 0000000000..22c29c6a42 --- /dev/null +++ b/web/app/components/app/configuration/config-vision/radio-group/style.module.css @@ -0,0 +1,24 @@ +.item { + @apply grow flex items-center h-8 px-2.5 rounded-lg bg-gray-25 border border-gray-100 cursor-pointer space-x-2; +} + +.item:hover { + background-color: #ffffff; + border-color: #B2CCFF; + box-shadow: 0px 12px 16px -4px rgba(16, 24, 40, 0.08), 0px 4px 6px -2px rgba(16, 24, 40, 0.03); +} + +.item.checked { + background-color: #ffffff; + border-color: #528BFF; + box-shadow: 0px 1px 2px 0px rgba(16, 24, 40, 0.06), 0px 1px 3px 0px rgba(16, 24, 40, 0.10); +} + +.radio { + @apply w-4 h-4 border-[2px] border-gray-200 rounded-full; +} + +.item.checked .radio { + border-width: 5px; + border-color: #155eef; +} \ No newline at end of file diff --git a/web/app/components/app/configuration/config-voice/param-config-content.tsx b/web/app/components/app/configuration/config-voice/param-config-content.tsx new file mode 100644 index 0000000000..4e70bdda21 --- /dev/null +++ b/web/app/components/app/configuration/config-voice/param-config-content.tsx @@ -0,0 +1,220 @@ +'use client' +import useSWR from 'swr' +import type { FC } from 'react' +import { useContext } from 'use-context-selector' +import React, { Fragment } from 'react' +import { usePathname } from 'next/navigation' +import { useTranslation } from 'react-i18next' +import { Listbox, Transition } from '@headlessui/react' +import { CheckIcon, ChevronDownIcon } from '@heroicons/react/20/solid' +import classNames from '@/utils/classnames' +import RadioGroup from '@/app/components/app/configuration/config-vision/radio-group' +import type { Item } from '@/app/components/base/select' +import ConfigContext from '@/context/debug-configuration' +import { fetchAppVoices } from '@/service/apps' +import Tooltip from '@/app/components/base/tooltip' +import { languages } from '@/i18n/language' +import { TtsAutoPlay } from '@/types/app' +const VoiceParamConfig: FC = () => { + const { t } = useTranslation() + const pathname = usePathname() + const matched = pathname.match(/\/app\/([^/]+)/) + const appId = (matched?.length && matched[1]) ? matched[1] : '' + + const { + textToSpeechConfig, + setTextToSpeechConfig, + } = useContext(ConfigContext) + + let languageItem = languages.find(item => item.value === textToSpeechConfig.language) + const localLanguagePlaceholder = languageItem?.name || t('common.placeholder.select') + if (languages && !languageItem && languages.length > 0) + languageItem = languages[0] + const language = languageItem?.value + const voiceItems = useSWR({ appId, language }, fetchAppVoices).data + let voiceItem = voiceItems?.find(item => item.value === textToSpeechConfig.voice) + if (voiceItems && !voiceItem && voiceItems.length > 0) + voiceItem = voiceItems[0] + + const localVoicePlaceholder = voiceItem?.name || t('common.placeholder.select') + + return ( +
+
+
{t('appDebug.voice.voiceSettings.title')}
+
+
+
+
{t('appDebug.voice.voiceSettings.language')}
+ + {t('appDebug.voice.voiceSettings.resolutionTooltip').split('\n').map(item => ( +
{item}
+ ))} +
+ } + /> +
+ { + setTextToSpeechConfig({ + ...textToSpeechConfig, + language: String(value.value), + }) + }} + > +
+ + + {languageItem?.name ? t(`common.voice.language.${languageItem?.value.replace('-', '')}`) : localLanguagePlaceholder} + + + + + + + + {languages.map((item: Item) => ( + + `relative cursor-pointer select-none py-2 pl-3 pr-9 rounded-lg hover:bg-gray-100 text-gray-700 ${active ? 'bg-gray-100' : '' + }` + } + value={item} + disabled={false} + > + {({ /* active, */ selected }) => ( + <> + {t(`common.voice.language.${(item.value).toString().replace('-', '')}`)} + {(selected || item.value === textToSpeechConfig.language) && ( + + + )} + + )} + + ))} + + +
+
+
+
+
{t('appDebug.voice.voiceSettings.voice')}
+ { + if (!value.value) + return + setTextToSpeechConfig({ + ...textToSpeechConfig, + voice: String(value.value), + }) + }} + > +
+ + {voiceItem?.name ?? localVoicePlaceholder} + + + + + + + {voiceItems?.map((item: Item) => ( + + `relative cursor-pointer select-none py-2 pl-3 pr-9 rounded-lg hover:bg-gray-100 text-gray-700 ${active ? 'bg-gray-100' : '' + }` + } + value={item} + disabled={false} + > + {({ /* active, */ selected }) => ( + <> + {item.name} + {(selected || item.value === textToSpeechConfig.voice) && ( + + + )} + + )} + + ))} + + +
+
+
+
+
{t('appDebug.voice.voiceSettings.autoPlay')}
+ { + setTextToSpeechConfig({ + ...textToSpeechConfig, + autoPlay: value, + }) + }} + /> +
+
+
+ + ) +} + +export default React.memo(VoiceParamConfig) diff --git a/web/app/components/app/configuration/config-voice/param-config.tsx b/web/app/components/app/configuration/config-voice/param-config.tsx new file mode 100644 index 0000000000..f1e2475495 --- /dev/null +++ b/web/app/components/app/configuration/config-voice/param-config.tsx @@ -0,0 +1,41 @@ +'use client' +import type { FC } from 'react' +import { memo, useState } from 'react' +import { useTranslation } from 'react-i18next' +import VoiceParamConfig from './param-config-content' +import cn from '@/utils/classnames' +import { Settings01 } from '@/app/components/base/icons/src/vender/line/general' +import { + PortalToFollowElem, + PortalToFollowElemContent, + PortalToFollowElemTrigger, +} from '@/app/components/base/portal-to-follow-elem' + +const ParamsConfig: FC = () => { + const { t } = useTranslation() + const [open, setOpen] = useState(false) + + return ( + + setOpen(v => !v)}> +
+ +
{t('appDebug.voice.settings')}
+
+
+ +
+ +
+
+
+ ) +} +export default memo(ParamsConfig) diff --git a/web/app/components/app/configuration/config/feature/add-feature-btn/index.tsx b/web/app/components/app/configuration/config/feature/add-feature-btn/index.tsx new file mode 100644 index 0000000000..eb3edc7593 --- /dev/null +++ b/web/app/components/app/configuration/config/feature/add-feature-btn/index.tsx @@ -0,0 +1,40 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { useTranslation } from 'react-i18next' +import { PlusIcon } from '@heroicons/react/24/solid' + +export type IAddFeatureBtnProps = { + toBottomHeight: number + onClick: () => void +} + +const ITEM_HEIGHT = 48 + +const AddFeatureBtn: FC = ({ + toBottomHeight, + onClick, +}) => { + const { t } = useTranslation() + return ( +
+
+ +
{t('appDebug.operation.addFeature')}
+
+
+ ) +} +export default React.memo(AddFeatureBtn) diff --git a/web/app/components/app/configuration/config/feature/choose-feature/feature-item/index.tsx b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/index.tsx new file mode 100644 index 0000000000..18623c11c3 --- /dev/null +++ b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/index.tsx @@ -0,0 +1,52 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import s from './style.module.css' +import cn from '@/utils/classnames' +import Switch from '@/app/components/base/switch' + +export type IFeatureItemProps = { + icon: React.ReactNode + previewImgClassName?: string + title: string + description: string + value: boolean + onChange: (value: boolean) => void +} + +const FeatureItem: FC = ({ + icon, + previewImgClassName, + title, + description, + value, + onChange, +}) => { + return ( +
+
+ {/* icon */} +
+ {icon} +
+
+
{title}
+
{description}
+
+
+ + + { + previewImgClassName && ( +
+
) + } +
+ ) +} +export default React.memo(FeatureItem) diff --git a/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/citation.png b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/citation.png new file mode 100644 index 0000000000..cc0847c942 Binary files /dev/null and b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/citation.png differ diff --git a/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/citation.svg b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/citation.svg new file mode 100644 index 0000000000..82fb182a7a --- /dev/null +++ b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/citation.svg @@ -0,0 +1,150 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/citations-and-attributions-preview@2x.png b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/citations-and-attributions-preview@2x.png new file mode 100644 index 0000000000..ef066204ca Binary files /dev/null and b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/citations-and-attributions-preview@2x.png differ diff --git a/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/conversation-opener-preview@2x.png b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/conversation-opener-preview@2x.png new file mode 100644 index 0000000000..15639d500d Binary files /dev/null and b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/conversation-opener-preview@2x.png differ diff --git a/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/more-like-this-preview@2x.png b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/more-like-this-preview@2x.png new file mode 100644 index 0000000000..62671c5889 Binary files /dev/null and b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/more-like-this-preview@2x.png differ diff --git a/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/more-like-this.png b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/more-like-this.png new file mode 100644 index 0000000000..9eb1c3a1ce Binary files /dev/null and b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/more-like-this.png differ diff --git a/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/more-like-this.svg b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/more-like-this.svg new file mode 100644 index 0000000000..6ccc84e70f --- /dev/null +++ b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/more-like-this.svg @@ -0,0 +1,188 @@ + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/next-question-suggestion-preview@2x.png b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/next-question-suggestion-preview@2x.png new file mode 100644 index 0000000000..758708ff15 Binary files /dev/null and b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/next-question-suggestion-preview@2x.png differ diff --git a/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/opening-statement.png b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/opening-statement.png new file mode 100644 index 0000000000..51e1bf395e Binary files /dev/null and b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/opening-statement.png differ diff --git a/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/opening-suggestion-preview@2x.png b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/opening-suggestion-preview@2x.png new file mode 100644 index 0000000000..8bb4add322 Binary files /dev/null and b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/opening-suggestion-preview@2x.png differ diff --git a/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/speech-to-text-preview@2x.png b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/speech-to-text-preview@2x.png new file mode 100644 index 0000000000..68df3983dc Binary files /dev/null and b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/speech-to-text-preview@2x.png differ diff --git a/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/speech-to-text.png b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/speech-to-text.png new file mode 100644 index 0000000000..c951d0c418 Binary files /dev/null and b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/speech-to-text.png differ diff --git a/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/speech-to-text.svg b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/speech-to-text.svg new file mode 100644 index 0000000000..029b92fee4 --- /dev/null +++ b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/speech-to-text.svg @@ -0,0 +1,100 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/suggested-questions-after-answer.png 
b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/suggested-questions-after-answer.png new file mode 100644 index 0000000000..bee4be0acf Binary files /dev/null and b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/suggested-questions-after-answer.png differ diff --git a/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/suggested-questions-after-answer.svg b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/suggested-questions-after-answer.svg new file mode 100644 index 0000000000..c0102b7e08 --- /dev/null +++ b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/suggested-questions-after-answer.svg @@ -0,0 +1,163 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/text-to-audio-preview-assistant@2x.png b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/text-to-audio-preview-assistant@2x.png new file mode 100644 index 0000000000..91396e72c7 Binary files /dev/null and b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/text-to-audio-preview-assistant@2x.png differ diff --git a/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/text-to-audio-preview-completion@2x.png b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/text-to-audio-preview-completion@2x.png new file mode 100644 index 0000000000..7558e78bd9 Binary files /dev/null and b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/preview-imgs/text-to-audio-preview-completion@2x.png differ diff --git a/web/app/components/app/configuration/config/feature/choose-feature/feature-item/style.module.css b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/style.module.css new file mode 100644 index 0000000000..32dc6be3ba --- /dev/null +++ b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/style.module.css @@ -0,0 +1,41 @@ +.preview { + display: none; + position: absolute; + top: 0; + left: 100%; + transform: translate(32px, -54px); + width: 280px; + height: 360px; + background: center center no-repeat; + background-size: contain; + border-radius: 8px; +} + +.wrap:hover .preview { + display: block; +} + +.openingStatementPreview { + background-image: url(./preview-imgs/opening-statement.png); +} + +.suggestedQuestionsAfterAnswerPreview { + background-image: url(./preview-imgs/suggested-questions-after-answer.png); +} + +.moreLikeThisPreview { + background-image: url(./preview-imgs/more-like-this.png); +} + +.speechToTextPreview { + background-image: url(./preview-imgs/speech-to-text.png); +} + +.textToSpeechPreview { + @apply shadow-lg rounded-lg; + background-image: url(./preview-imgs/text-to-audio-preview-assistant@2x.png); +} + +.citationPreview { + background-image: url(./preview-imgs/citation.png); +} diff --git a/web/app/components/app/configuration/config/feature/choose-feature/index.tsx 
b/web/app/components/app/configuration/config/feature/choose-feature/index.tsx new file mode 100644 index 0000000000..8364f9529d --- /dev/null +++ b/web/app/components/app/configuration/config/feature/choose-feature/index.tsx @@ -0,0 +1,172 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { useTranslation } from 'react-i18next' +import FeatureGroup from '../feature-group' +import MoreLikeThisIcon from '../../../base/icons/more-like-this-icon' +import FeatureItem from './feature-item' +import Modal from '@/app/components/base/modal' +import SuggestedQuestionsAfterAnswerIcon from '@/app/components/app/configuration/base/icons/suggested-questions-after-answer-icon' +import { Microphone01, Speaker } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices' +import { Citations } from '@/app/components/base/icons/src/vender/solid/editor' +import { FileSearch02 } from '@/app/components/base/icons/src/vender/solid/files' +import { MessageFast } from '@/app/components/base/icons/src/vender/solid/communication' +type IConfig = { + openingStatement: boolean + moreLikeThis: boolean + suggestedQuestionsAfterAnswer: boolean + speechToText: boolean + textToSpeech: boolean + citation: boolean + moderation: boolean + annotation: boolean +} + +export type IChooseFeatureProps = { + isShow: boolean + onClose: () => void + config: IConfig + isChatApp: boolean + onChange: (key: string, value: boolean) => void + showTextToSpeechItem?: boolean + showSpeechToTextItem?: boolean +} + +const OpeningStatementIcon = ( + + + +) + +const ChooseFeature: FC = ({ + isShow, + onClose, + isChatApp, + config, + onChange, + showTextToSpeechItem, + showSpeechToTextItem, +}) => { + const { t } = useTranslation() + return ( + +
+ {/* Chat Feature */} + {isChatApp && ( + + <> + onChange('openingStatement', value)} + /> + } + previewImgClassName='suggestedQuestionsAfterAnswerPreview' + title={t('appDebug.feature.suggestedQuestionsAfterAnswer.title')} + description={t('appDebug.feature.suggestedQuestionsAfterAnswer.description')} + value={config.suggestedQuestionsAfterAnswer} + onChange={value => onChange('suggestedQuestionsAfterAnswer', value)} + /> + { + showTextToSpeechItem && ( + } + previewImgClassName='textToSpeechPreview' + title={t('appDebug.feature.textToSpeech.title')} + description={t('appDebug.feature.textToSpeech.description')} + value={config.textToSpeech} + onChange={value => onChange('textToSpeech', value)} + /> + ) + } + { + showSpeechToTextItem && ( + } + previewImgClassName='speechToTextPreview' + title={t('appDebug.feature.speechToText.title')} + description={t('appDebug.feature.speechToText.description')} + value={config.speechToText} + onChange={value => onChange('speechToText', value)} + /> + ) + } + } + previewImgClassName='citationPreview' + title={t('appDebug.feature.citation.title')} + description={t('appDebug.feature.citation.description')} + value={config.citation} + onChange={value => onChange('citation', value)} + /> + + + )} + + {/* Text Generation Feature */} + {!isChatApp && ( + + <> + } + previewImgClassName='moreLikeThisPreview' + title={t('appDebug.feature.moreLikeThis.title')} + description={t('appDebug.feature.moreLikeThis.description')} + value={config.moreLikeThis} + onChange={value => onChange('moreLikeThis', value)} + /> + { + showTextToSpeechItem && ( + } + previewImgClassName='textToSpeechPreview' + title={t('appDebug.feature.textToSpeech.title')} + description={t('appDebug.feature.textToSpeech.description')} + value={config.textToSpeech} + onChange={value => onChange('textToSpeech', value)} + /> + ) + } + + + )} + + <> + } + previewImgClassName='' + title={t('appDebug.feature.moderation.title')} + description={t('appDebug.feature.moderation.description')} + value={config.moderation} + onChange={value => onChange('moderation', value)} + /> + {isChatApp && ( + } + title={t('appDebug.feature.annotation.title')} + description={t('appDebug.feature.annotation.description')} + value={config.annotation} + onChange={value => onChange('annotation', value)} + /> + )} + + +
+
+ ) +} +export default React.memo(ChooseFeature) diff --git a/web/app/components/app/configuration/config/feature/feature-group/index.tsx b/web/app/components/app/configuration/config/feature/feature-group/index.tsx new file mode 100644 index 0000000000..a4b27f18d4 --- /dev/null +++ b/web/app/components/app/configuration/config/feature/feature-group/index.tsx @@ -0,0 +1,31 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import GroupName from '@/app/components/app/configuration/base/group-name' + +export type IFeatureGroupProps = { + title: string + description?: string + children: React.ReactNode +} + +const FeatureGroup: FC = ({ + title, + description, + children, +}) => { + return ( +
+
+ + {description && ( +
{description}
+ )} +
+
+ {children} +
+
+ ) +} +export default React.memo(FeatureGroup) diff --git a/web/app/components/app/configuration/features/chat-group/citation/index.tsx b/web/app/components/app/configuration/features/chat-group/citation/index.tsx new file mode 100644 index 0000000000..4003b68cd3 --- /dev/null +++ b/web/app/components/app/configuration/features/chat-group/citation/index.tsx @@ -0,0 +1,25 @@ +'use client' +import React, { type FC } from 'react' +import { useTranslation } from 'react-i18next' +import Panel from '@/app/components/app/configuration/base/feature-panel' +import { Citations } from '@/app/components/base/icons/src/vender/solid/editor' + +const Citation: FC = () => { + const { t } = useTranslation() + + return ( + +
{t('appDebug.feature.citation.title')}
+ + } + headerIcon={} + headerRight={ +
{t('appDebug.feature.citation.resDes')}
+ } + noBodySpacing + /> + ) +} +export default React.memo(Citation) diff --git a/web/app/components/app/configuration/features/chat-group/index.tsx b/web/app/components/app/configuration/features/chat-group/index.tsx new file mode 100644 index 0000000000..fd3cfa3a68 --- /dev/null +++ b/web/app/components/app/configuration/features/chat-group/index.tsx @@ -0,0 +1,65 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { useTranslation } from 'react-i18next' +import GroupName from '../../base/group-name' +import type { IOpeningStatementProps } from './opening-statement' +import OpeningStatement from './opening-statement' +import SuggestedQuestionsAfterAnswer from './suggested-questions-after-answer' +import SpeechToText from './speech-to-text' +import TextToSpeech from './text-to-speech' +import Citation from './citation' +/* +* Include +* 1. Conversation Opener +* 2. Opening Suggestion +* 3. Next question suggestion +*/ +type ChatGroupProps = { + isShowOpeningStatement: boolean + openingStatementConfig: IOpeningStatementProps + isShowSuggestedQuestionsAfterAnswer: boolean + isShowSpeechText: boolean + isShowTextToSpeech: boolean + isShowCitation: boolean +} +const ChatGroup: FC = ({ + isShowOpeningStatement, + openingStatementConfig, + isShowSuggestedQuestionsAfterAnswer, + isShowSpeechText, + isShowTextToSpeech, + isShowCitation, +}) => { + const { t } = useTranslation() + + return ( +
+ +
+ {isShowOpeningStatement && ( + + )} + {isShowSuggestedQuestionsAfterAnswer && ( + + )} + { + isShowTextToSpeech && ( + + ) + } + { + isShowSpeechText && ( + + ) + } + { + isShowCitation && ( + + ) + } +
+
+ ) +} +export default React.memo(ChatGroup) diff --git a/web/app/components/app/configuration/features/chat-group/speech-to-text/index.tsx b/web/app/components/app/configuration/features/chat-group/speech-to-text/index.tsx new file mode 100644 index 0000000000..e452b38971 --- /dev/null +++ b/web/app/components/app/configuration/features/chat-group/speech-to-text/index.tsx @@ -0,0 +1,25 @@ +'use client' +import React, { type FC } from 'react' +import { useTranslation } from 'react-i18next' +import Panel from '@/app/components/app/configuration/base/feature-panel' +import { Microphone01 } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices' + +const SpeechToTextConfig: FC = () => { + const { t } = useTranslation() + + return ( + +
{t('appDebug.feature.speechToText.title')}
+ + } + headerIcon={} + headerRight={ +
{t('appDebug.feature.speechToText.resDes')}
+ } + noBodySpacing + /> + ) +} +export default React.memo(SpeechToTextConfig) diff --git a/web/app/components/app/configuration/features/chat-group/suggested-questions-after-answer/index.tsx b/web/app/components/app/configuration/features/chat-group/suggested-questions-after-answer/index.tsx new file mode 100644 index 0000000000..199558f4aa --- /dev/null +++ b/web/app/components/app/configuration/features/chat-group/suggested-questions-after-answer/index.tsx @@ -0,0 +1,34 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { useTranslation } from 'react-i18next' +import Panel from '@/app/components/app/configuration/base/feature-panel' +import SuggestedQuestionsAfterAnswerIcon from '@/app/components/app/configuration/base/icons/suggested-questions-after-answer-icon' +import Tooltip from '@/app/components/base/tooltip' + +const SuggestedQuestionsAfterAnswer: FC = () => { + const { t } = useTranslation() + + return ( + +
{t('appDebug.feature.suggestedQuestionsAfterAnswer.title')}
+ + {t('appDebug.feature.suggestedQuestionsAfterAnswer.description')} + + } + /> + + } + headerIcon={} + headerRight={ +
{t('appDebug.feature.suggestedQuestionsAfterAnswer.resDes')}
+ } + noBodySpacing + /> + ) +} +export default React.memo(SuggestedQuestionsAfterAnswer) diff --git a/web/app/components/app/configuration/features/chat-group/text-to-speech/index.tsx b/web/app/components/app/configuration/features/chat-group/text-to-speech/index.tsx new file mode 100644 index 0000000000..72d617c3c3 --- /dev/null +++ b/web/app/components/app/configuration/features/chat-group/text-to-speech/index.tsx @@ -0,0 +1,55 @@ +'use client' +import useSWR from 'swr' +import React, { type FC } from 'react' +import { useTranslation } from 'react-i18next' +import { useContext } from 'use-context-selector' +import { usePathname } from 'next/navigation' +import Panel from '@/app/components/app/configuration/base/feature-panel' +import { Speaker } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices' +import ConfigContext from '@/context/debug-configuration' +import { languages } from '@/i18n/language' +import { fetchAppVoices } from '@/service/apps' +import AudioBtn from '@/app/components/base/audio-btn' + +const TextToSpeech: FC = () => { + const { t } = useTranslation() + const { + textToSpeechConfig, + } = useContext(ConfigContext) + + const pathname = usePathname() + const matched = pathname.match(/\/app\/([^/]+)/) + const appId = (matched?.length && matched[1]) ? matched[1] : '' + const language = textToSpeechConfig.language + const languageInfo = languages.find(i => i.value === textToSpeechConfig.language) + + const voiceItems = useSWR({ appId, language }, fetchAppVoices).data + const voiceItem = voiceItems?.find(item => item.value === textToSpeechConfig.voice) + + return ( + +
{t('appDebug.feature.textToSpeech.title')}
+ + } + headerIcon={} + headerRight={ +
+ {languageInfo && (`${languageInfo?.name} - `)}{voiceItem?.name ?? t('appDebug.voice.defaultDisplay')} + { languageInfo?.example && ( + + )} +
+ } + noBodySpacing + isShowTextToSpeech + /> + ) +} +export default React.memo(TextToSpeech) diff --git a/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx b/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx new file mode 100644 index 0000000000..809b907d62 --- /dev/null +++ b/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx @@ -0,0 +1,135 @@ +'use client' +import type { FC } from 'react' +import React, { useRef, useState } from 'react' +import { useHover } from 'ahooks' +import { useTranslation } from 'react-i18next' +import cn from '@/utils/classnames' +import { MessageCheckRemove, MessageFastPlus } from '@/app/components/base/icons/src/vender/line/communication' +import { MessageFast } from '@/app/components/base/icons/src/vender/solid/communication' +import { Edit04 } from '@/app/components/base/icons/src/vender/line/general' +import RemoveAnnotationConfirmModal from '@/app/components/app/annotation/remove-annotation-confirm-modal' +import Tooltip from '@/app/components/base/tooltip' +import { addAnnotation, delAnnotation } from '@/service/annotation' +import Toast from '@/app/components/base/toast' +import { useProviderContext } from '@/context/provider-context' +import { useModalContext } from '@/context/modal-context' + +type Props = { + appId: string + messageId?: string + annotationId?: string + className?: string + cached: boolean + query: string + answer: string + onAdded: (annotationId: string, authorName: string) => void + onEdit: () => void + onRemoved: () => void +} + +const CacheCtrlBtn: FC = ({ + className, + cached, + query, + answer, + appId, + messageId, + annotationId, + onAdded, + onEdit, + onRemoved, +}) => { + const { t } = useTranslation() + const { plan, enableBilling } = useProviderContext() + const isAnnotationFull = (enableBilling && plan.usage.annotatedResponse >= plan.total.annotatedResponse) + const { setShowAnnotationFullModal } = useModalContext() + const [showModal, setShowModal] = useState(false) + const cachedBtnRef = useRef(null) + const isCachedBtnHovering = useHover(cachedBtnRef) + const handleAdd = async () => { + if (isAnnotationFull) { + setShowAnnotationFullModal() + return + } + const res: any = await addAnnotation(appId, { + message_id: messageId, + question: query, + answer, + }) + Toast.notify({ + message: t('common.api.actionSuccess') as string, + type: 'success', + }) + onAdded(res.id, res.account?.name) + } + + const handleRemove = async () => { + await delAnnotation(appId, annotationId!) + Toast.notify({ + message: t('common.api.actionSuccess') as string, + type: 'success', + }) + onRemoved() + setShowModal(false) + } + return ( +
+
+ {cached + ? ( +
+
setShowModal(true)} + > + {!isCachedBtnHovering + ? ( + <> + +
{t('appDebug.feature.annotation.cached')}
+ + ) + : <> + +
{t('appDebug.feature.annotation.remove')}
+ } +
+
+ ) + : answer + ? ( + +
+ +
+
+ ) + : null + } + +
+ +
+
+ +
+ setShowModal(false)} + onRemove={handleRemove} + /> +
+ ) +} +export default React.memo(CacheCtrlBtn) diff --git a/web/app/components/app/configuration/toolbox/annotation/config-param-modal.tsx b/web/app/components/app/configuration/toolbox/annotation/config-param-modal.tsx new file mode 100644 index 0000000000..b660977d08 --- /dev/null +++ b/web/app/components/app/configuration/toolbox/annotation/config-param-modal.tsx @@ -0,0 +1,139 @@ +'use client' +import type { FC } from 'react' +import React, { useState } from 'react' +import { useTranslation } from 'react-i18next' +import ScoreSlider from '../score-slider' +import { Item } from './config-param' +import Modal from '@/app/components/base/modal' +import Button from '@/app/components/base/button' +import Toast from '@/app/components/base/toast' +import type { AnnotationReplyConfig } from '@/models/debug' +import { ANNOTATION_DEFAULT } from '@/config' +import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' +import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' + +type Props = { + appId: string + isShow: boolean + onHide: () => void + onSave: (embeddingModel: { + embedding_provider_name: string + embedding_model_name: string + }, score: number) => void + isInit?: boolean + annotationConfig: AnnotationReplyConfig +} + +const ConfigParamModal: FC = ({ + isShow, + onHide: doHide, + onSave, + isInit, + annotationConfig: oldAnnotationConfig, +}) => { + const { t } = useTranslation() + const { + modelList: embeddingsModelList, + defaultModel: embeddingsDefaultModel, + currentModel: isEmbeddingsDefaultModelValid, + } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textEmbedding) + const [annotationConfig, setAnnotationConfig] = useState(oldAnnotationConfig) + + const [isLoading, setLoading] = useState(false) + const [embeddingModel, setEmbeddingModel] = useState(oldAnnotationConfig.embedding_model + ? { + providerName: oldAnnotationConfig.embedding_model.embedding_provider_name, + modelName: oldAnnotationConfig.embedding_model.embedding_model_name, + } + : (embeddingsDefaultModel + ? { + providerName: embeddingsDefaultModel.provider.provider, + modelName: embeddingsDefaultModel.model, + } + : undefined)) + const onHide = () => { + if (!isLoading) + doHide() + } + + const handleSave = async () => { + if (!embeddingModel || !embeddingModel.modelName || (embeddingModel.modelName === embeddingsDefaultModel?.model && !isEmbeddingsDefaultModelValid)) { + Toast.notify({ + message: t('common.modelProvider.embeddingModel.required'), + type: 'error', + }) + return + } + setLoading(true) + await onSave({ + embedding_provider_name: embeddingModel.providerName, + embedding_model_name: embeddingModel.modelName, + }, annotationConfig.score_threshold) + setLoading(false) + } + + return ( + +
+ {t(`appAnnotation.initSetup.${isInit ? 'title' : 'configTitle'}`)} +
+ +
+ + { + setAnnotationConfig({ + ...annotationConfig, + score_threshold: val / 100, + }) + }} + /> + + + +
+ { + setEmbeddingModel({ + providerName: val.provider, + modelName: val.model, + }) + }} + /> +
+
+
+ +
+ + +
+
+ ) +} +export default React.memo(ConfigParamModal) diff --git a/web/app/components/app/configuration/toolbox/annotation/type.ts b/web/app/components/app/configuration/toolbox/annotation/type.ts new file mode 100644 index 0000000000..910453478c --- /dev/null +++ b/web/app/components/app/configuration/toolbox/annotation/type.ts @@ -0,0 +1,4 @@ +export enum PageType { + log = 'log', + annotation = 'annotation', +} diff --git a/web/app/components/app/configuration/toolbox/annotation/use-annotation-config.ts b/web/app/components/app/configuration/toolbox/annotation/use-annotation-config.ts new file mode 100644 index 0000000000..540302cb27 --- /dev/null +++ b/web/app/components/app/configuration/toolbox/annotation/use-annotation-config.ts @@ -0,0 +1,89 @@ +import React, { useState } from 'react' +import produce from 'immer' +import type { AnnotationReplyConfig } from '@/models/debug' +import { queryAnnotationJobStatus, updateAnnotationStatus } from '@/service/annotation' +import type { EmbeddingModelConfig } from '@/app/components/app/annotation/type' +import { AnnotationEnableStatus, JobStatus } from '@/app/components/app/annotation/type' +import { sleep } from '@/utils' +import { ANNOTATION_DEFAULT } from '@/config' +import { useProviderContext } from '@/context/provider-context' + +type Params = { + appId: string + annotationConfig: AnnotationReplyConfig + setAnnotationConfig: (annotationConfig: AnnotationReplyConfig) => void +} +const useAnnotationConfig = ({ + appId, + annotationConfig, + setAnnotationConfig, +}: Params) => { + const { plan, enableBilling } = useProviderContext() + const isAnnotationFull = (enableBilling && plan.usage.annotatedResponse >= plan.total.annotatedResponse) + const [isShowAnnotationFullModal, setIsShowAnnotationFullModal] = useState(false) + const [isShowAnnotationConfigInit, doSetIsShowAnnotationConfigInit] = React.useState(false) + const setIsShowAnnotationConfigInit = (isShow: boolean) => { + if (isShow) { + if (isAnnotationFull) { + setIsShowAnnotationFullModal(true) + return + } + } + doSetIsShowAnnotationConfigInit(isShow) + } + const ensureJobCompleted = async (jobId: string, status: AnnotationEnableStatus) => { + let isCompleted = false + while (!isCompleted) { + const res: any = await queryAnnotationJobStatus(appId, status, jobId) + isCompleted = res.job_status === JobStatus.completed + if (isCompleted) + break + + await sleep(2000) + } + } + + const handleEnableAnnotation = async (embeddingModel: EmbeddingModelConfig, score?: number) => { + if (isAnnotationFull) + return + + const { job_id: jobId }: any = await updateAnnotationStatus(appId, AnnotationEnableStatus.enable, embeddingModel, score) + await ensureJobCompleted(jobId, AnnotationEnableStatus.enable) + setAnnotationConfig(produce(annotationConfig, (draft: AnnotationReplyConfig) => { + draft.enabled = true + draft.embedding_model = embeddingModel + if (!draft.score_threshold) + draft.score_threshold = ANNOTATION_DEFAULT.score_threshold + })) + } + + const setScore = (score: number, embeddingModel?: EmbeddingModelConfig) => { + setAnnotationConfig(produce(annotationConfig, (draft: AnnotationReplyConfig) => { + draft.score_threshold = score + if (embeddingModel) + draft.embedding_model = embeddingModel + })) + } + + const handleDisableAnnotation = async (embeddingModel: EmbeddingModelConfig) => { + if (!annotationConfig.enabled) + return + + await updateAnnotationStatus(appId, AnnotationEnableStatus.disable, embeddingModel) + setAnnotationConfig(produce(annotationConfig, (draft: 
AnnotationReplyConfig) => { + draft.enabled = false + })) + } + + return { + handleEnableAnnotation, + handleDisableAnnotation, + isShowAnnotationConfigInit, + setIsShowAnnotationConfigInit, + isShowAnnotationFullModal, + setIsShowAnnotationFullModal, + setScore, + } +} + +export default useAnnotationConfig diff --git a/web/app/components/app/configuration/toolbox/moderation/form-generation.tsx b/web/app/components/app/configuration/toolbox/moderation/form-generation.tsx new file mode 100644 index 0000000000..daf964447b --- /dev/null +++ b/web/app/components/app/configuration/toolbox/moderation/form-generation.tsx @@ -0,0 +1,79 @@ +import type { FC } from 'react' +import { useContext } from 'use-context-selector' +import type { CodeBasedExtensionForm } from '@/models/common' +import I18n from '@/context/i18n' +import { PortalSelect } from '@/app/components/base/select' +import type { ModerationConfig } from '@/models/debug' + +type FormGenerationProps = { + forms: CodeBasedExtensionForm[] + value: ModerationConfig['config'] + onChange: (v: Record) => void +} +const FormGeneration: FC = ({ + forms, + value, + onChange, +}) => { + const { locale } = useContext(I18n) + + const handleFormChange = (type: string, v: string) => { + onChange({ ...value, [type]: v }) + } + + return ( + <> + { + forms.map((form, index) => ( +
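The ensureJobCompleted helper in use-annotation-config.ts above is a plain poll-until-done loop: query the annotation job status, sleep two seconds, repeat until it reports completed. The same pattern, sketched in Python for illustration only; query_job_status is a hypothetical callable standing in for queryAnnotationJobStatus and is assumed to return a status string.

    import time

    def ensure_job_completed(query_job_status, job_id: str, interval_s: float = 2.0) -> None:
        # Poll the job until it reports "completed", pausing between attempts.
        while True:
            if query_job_status(job_id) == "completed":
                return
            time.sleep(interval_s)

As in the hook, there is deliberately no timeout: enabling or disabling annotation reply blocks the caller until the background job finishes.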
+
+ {locale === 'zh-Hans' ? form.label['zh-Hans'] : form.label['en-US']} +
+ { + form.type === 'text-input' && ( + handleFormChange(form.variable, e.target.value)} + /> + ) + } + { + form.type === 'paragraph' && ( +
+ +
+ ) + : ( +
+ )} + {renderQuestions()} + ) : ( +
{t('appDebug.openingStatement.noDataPlaceHolder')}
+ )} + + {isShowConfirmAddVar && ( + + )} + +
+
+ ) +} +export default React.memo(OpeningStatement) diff --git a/web/app/components/base/features/feature-panel/score-slider/base-slider/index.tsx b/web/app/components/base/features/feature-panel/score-slider/base-slider/index.tsx new file mode 100644 index 0000000000..2e08a99122 --- /dev/null +++ b/web/app/components/base/features/feature-panel/score-slider/base-slider/index.tsx @@ -0,0 +1,38 @@ +import ReactSlider from 'react-slider' +import s from './style.module.css' +import cn from '@/utils/classnames' + +type ISliderProps = { + className?: string + value: number + max?: number + min?: number + step?: number + disabled?: boolean + onChange: (value: number) => void +} + +const Slider: React.FC = ({ className, max, min, step, value, disabled, onChange }) => { + return ( +
+
+
+ {(state.valueNow / 100).toFixed(2)} +
+
+
+ )} + /> +} + +export default Slider diff --git a/web/app/components/base/features/feature-panel/score-slider/base-slider/style.module.css b/web/app/components/base/features/feature-panel/score-slider/base-slider/style.module.css new file mode 100644 index 0000000000..4e93b39563 --- /dev/null +++ b/web/app/components/base/features/feature-panel/score-slider/base-slider/style.module.css @@ -0,0 +1,20 @@ +.slider { + position: relative; +} + +.slider.disabled { + opacity: 0.6; +} + +.slider-thumb:focus { + outline: none; +} + +.slider-track { + background-color: #528BFF; + height: 2px; +} + +.slider-track-1 { + background-color: #E5E7EB; +} \ No newline at end of file diff --git a/web/app/components/base/features/feature-panel/score-slider/index.tsx b/web/app/components/base/features/feature-panel/score-slider/index.tsx new file mode 100644 index 0000000000..9826cbadcf --- /dev/null +++ b/web/app/components/base/features/feature-panel/score-slider/index.tsx @@ -0,0 +1,46 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { useTranslation } from 'react-i18next' +import Slider from '@/app/components/app/configuration/toolbox/score-slider/base-slider' + +type Props = { + className?: string + value: number + onChange: (value: number) => void +} + +const ScoreSlider: FC = ({ + className, + value, + onChange, +}) => { + const { t } = useTranslation() + + return ( +
+
+ +
+
+
+
0.8
+
·
+
{t('appDebug.feature.annotation.scoreThreshold.easyMatch')}
+
+
+
1.0
+
·
+
{t('appDebug.feature.annotation.scoreThreshold.accurateMatch')}
+
+
+
+ ) +} +export default React.memo(ScoreSlider) diff --git a/web/app/components/base/features/feature-panel/speech-to-text/index.tsx b/web/app/components/base/features/feature-panel/speech-to-text/index.tsx new file mode 100644 index 0000000000..2e5e3de439 --- /dev/null +++ b/web/app/components/base/features/feature-panel/speech-to-text/index.tsx @@ -0,0 +1,22 @@ +'use client' +import React, { type FC } from 'react' +import { useTranslation } from 'react-i18next' +import { Microphone01 } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices' + +const SpeechToTextConfig: FC = () => { + const { t } = useTranslation() + + return ( +
+
+ +
+
+
{t('appDebug.feature.speechToText.title')}
+
+
+
{t('appDebug.feature.speechToText.resDes')}
+
+ ) +} +export default React.memo(SpeechToTextConfig) diff --git a/web/app/components/base/features/feature-panel/suggested-questions-after-answer/index.tsx b/web/app/components/base/features/feature-panel/suggested-questions-after-answer/index.tsx new file mode 100644 index 0000000000..e6d0b6e7e0 --- /dev/null +++ b/web/app/components/base/features/feature-panel/suggested-questions-after-answer/index.tsx @@ -0,0 +1,25 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { useTranslation } from 'react-i18next' +import { MessageSmileSquare } from '@/app/components/base/icons/src/vender/solid/communication' +import Tooltip from '@/app/components/base/tooltip' + +const SuggestedQuestionsAfterAnswer: FC = () => { + const { t } = useTranslation() + + return ( +
+
+ +
+
+
{t('appDebug.feature.suggestedQuestionsAfterAnswer.title')}
+ +
+
+
{t('appDebug.feature.suggestedQuestionsAfterAnswer.resDes')}
+
+ ) +} +export default React.memo(SuggestedQuestionsAfterAnswer) diff --git a/web/app/components/base/features/feature-panel/text-to-speech/index.tsx b/web/app/components/base/features/feature-panel/text-to-speech/index.tsx new file mode 100644 index 0000000000..2480a19077 --- /dev/null +++ b/web/app/components/base/features/feature-panel/text-to-speech/index.tsx @@ -0,0 +1,62 @@ +'use client' +import useSWR from 'swr' +import React from 'react' +import { useTranslation } from 'react-i18next' +import { usePathname } from 'next/navigation' +import { useFeatures } from '../../hooks' +import type { OnFeaturesChange } from '../../types' +import ParamsConfig from './params-config' +import { Speaker } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices' +import { languages } from '@/i18n/language' +import { fetchAppVoices } from '@/service/apps' +import AudioBtn from '@/app/components/base/audio-btn' + +type TextToSpeechProps = { + onChange?: OnFeaturesChange + disabled?: boolean +} +const TextToSpeech = ({ + onChange, + disabled, +}: TextToSpeechProps) => { + const { t } = useTranslation() + const textToSpeech = useFeatures(s => s.features.text2speech) + + const pathname = usePathname() + const matched = pathname.match(/\/app\/([^/]+)/) + const appId = (matched?.length && matched[1]) ? matched[1] : '' + const language = textToSpeech?.language + const languageInfo = languages.find(i => i.value === textToSpeech?.language) + + const voiceItems = useSWR({ appId, language }, fetchAppVoices).data + const voiceItem = voiceItems?.find(item => item.value === textToSpeech?.voice) + + return ( +
+
+ +
+
+ {t('appDebug.feature.textToSpeech.title')} +
+
+
+
+ {languageInfo && (`${languageInfo?.name} - `)}{voiceItem?.name ?? t('appDebug.voice.defaultDisplay')} + { languageInfo?.example && ( + + )} +
+
+ +
+
+ ) +} +export default React.memo(TextToSpeech) diff --git a/web/app/components/base/features/feature-panel/text-to-speech/param-config-content.tsx b/web/app/components/base/features/feature-panel/text-to-speech/param-config-content.tsx new file mode 100644 index 0000000000..e923d9a333 --- /dev/null +++ b/web/app/components/base/features/feature-panel/text-to-speech/param-config-content.tsx @@ -0,0 +1,241 @@ +'use client' +import useSWR from 'swr' +import produce from 'immer' +import React, { Fragment } from 'react' +import { usePathname } from 'next/navigation' +import { useTranslation } from 'react-i18next' +import { Listbox, Transition } from '@headlessui/react' +import { CheckIcon, ChevronDownIcon } from '@heroicons/react/20/solid' +import { + useFeatures, + useFeaturesStore, +} from '../../hooks' +import type { OnFeaturesChange } from '../../types' +import classNames from '@/utils/classnames' +import type { Item } from '@/app/components/base/select' +import { fetchAppVoices } from '@/service/apps' +import Tooltip from '@/app/components/base/tooltip' +import { languages } from '@/i18n/language' +import RadioGroup from '@/app/components/app/configuration/config-vision/radio-group' +import { TtsAutoPlay } from '@/types/app' + +type VoiceParamConfigProps = { + onChange?: OnFeaturesChange +} +const VoiceParamConfig = ({ + onChange, +}: VoiceParamConfigProps) => { + const { t } = useTranslation() + const pathname = usePathname() + const matched = pathname.match(/\/app\/([^/]+)/) + const appId = (matched?.length && matched[1]) ? matched[1] : '' + const text2speech = useFeatures(state => state.features.text2speech) + const featuresStore = useFeaturesStore() + + let languageItem = languages.find(item => item.value === text2speech?.language) + if (languages && !languageItem) + languageItem = languages[0] + const localLanguagePlaceholder = languageItem?.name || t('common.placeholder.select') + + const language = languageItem?.value + const voiceItems = useSWR({ appId, language }, fetchAppVoices).data + let voiceItem = voiceItems?.find(item => item.value === text2speech?.voice) + if (voiceItems && !voiceItem) + voiceItem = voiceItems[0] + const localVoicePlaceholder = voiceItem?.name || t('common.placeholder.select') + + const handleChange = (value: Record) => { + const { + features, + setFeatures, + } = featuresStore!.getState() + + const newFeatures = produce(features, (draft) => { + draft.text2speech = { + ...draft.text2speech, + ...value, + } + }) + + setFeatures(newFeatures) + if (onChange) + onChange(newFeatures) + } + + return ( +
+      <div>
+        <div className='leading-6 text-base font-semibold text-gray-800'>{t('appDebug.voice.voiceSettings.title')}</div>
+        <div className='pt-3 space-y-6'>
+          <div>
+            <div className='mb-2 flex items-center space-x-1 text-[13px] font-semibold text-gray-800'>
+              <div>{t('appDebug.voice.voiceSettings.language')}</div>
+              <Tooltip
+                popupContent={
+                  <div className='w-[180px]'>
+                    {t('appDebug.voice.voiceSettings.resolutionTooltip').split('\n').map(item => (
+                      <div key={item}>{item}</div>
+                    ))}
+                  </div>
+                }
+              />
+            </div>
+            <Listbox
+              value={languageItem}
+              onChange={(value: Item) => {
+                handleChange({
+                  language: String(value.value),
+                })
+              }}
+            >
+              <div className='relative'>
+                <Listbox.Button className='w-full cursor-pointer rounded-lg border-0 bg-gray-100 py-1.5 pl-3 pr-10 text-left sm:text-sm sm:leading-6 focus-visible:outline-none'>
+                  <span className='block truncate'>
+                    {languageItem?.name ? t(`common.voice.language.${languageItem?.value.replace('-', '')}`) : localLanguagePlaceholder}
+                  </span>
+                  <span className='pointer-events-none absolute inset-y-0 right-0 flex items-center pr-2'>
+                    <ChevronDownIcon className='h-5 w-5 text-gray-400' aria-hidden='true' />
+                  </span>
+                </Listbox.Button>
+                <Transition
+                  as={Fragment}
+                  leave='transition ease-in duration-100'
+                  leaveFrom='opacity-100'
+                  leaveTo='opacity-0'
+                >
+                  <Listbox.Options className='absolute z-10 mt-1 max-h-60 w-full overflow-auto rounded-md bg-white px-1 py-1 text-base shadow-lg focus:outline-none sm:text-sm'>
+                    {languages.map((item: Item) => (
+                      <Listbox.Option
+                        key={item.value}
+                        className={({ active }) =>
+                          `relative cursor-pointer select-none py-2 pl-3 pr-9 rounded-lg hover:bg-gray-100 text-gray-700 ${active ? 'bg-gray-100' : ''}`
+                        }
+                        value={item}
+                        disabled={false}
+                      >
+                        {({ /* active, */ selected }) => (
+                          <>
+                            <span className='block truncate'>{t(`common.voice.language.${(item.value).toString().replace('-', '')}`)}</span>
+                            {(selected || item.value === text2speech?.language) && (
+                              <span className='absolute inset-y-0 right-0 flex items-center pr-4'>
+                                <CheckIcon className='h-5 w-5' aria-hidden='true' />
+                              </span>
+                            )}
+                          </>
+                        )}
+                      </Listbox.Option>
+                    ))}
+                  </Listbox.Options>
+                </Transition>
+              </div>
+            </Listbox>
+          </div>
+          <div>
+            <div className='mb-2 text-[13px] font-semibold text-gray-800'>{t('appDebug.voice.voiceSettings.voice')}</div>
+            <Listbox
+              value={voiceItem}
+              onChange={(value: Item) => {
+                handleChange({
+                  voice: String(value.value),
+                })
+              }}
+            >
+              <div className='relative'>
+                <Listbox.Button className='w-full cursor-pointer rounded-lg border-0 bg-gray-100 py-1.5 pl-3 pr-10 text-left sm:text-sm sm:leading-6 focus-visible:outline-none'>
+                  <span className='block truncate'>{voiceItem?.name ?? localVoicePlaceholder}</span>
+                  <span className='pointer-events-none absolute inset-y-0 right-0 flex items-center pr-2'>
+                    <ChevronDownIcon className='h-5 w-5 text-gray-400' aria-hidden='true' />
+                  </span>
+                </Listbox.Button>
+                <Transition
+                  as={Fragment}
+                  leave='transition ease-in duration-100'
+                  leaveFrom='opacity-100'
+                  leaveTo='opacity-0'
+                >
+                  <Listbox.Options className='absolute z-10 mt-1 max-h-60 w-full overflow-auto rounded-md bg-white px-1 py-1 text-base shadow-lg focus:outline-none sm:text-sm'>
+                    {voiceItems?.map((item: Item) => (
+                      <Listbox.Option
+                        key={item.value}
+                        className={({ active }) =>
+                          `relative cursor-pointer select-none py-2 pl-3 pr-9 rounded-lg hover:bg-gray-100 text-gray-700 ${active ? 'bg-gray-100' : ''}`
+                        }
+                        value={item}
+                        disabled={false}
+                      >
+                        {({ /* active, */ selected }) => (
+                          <>
+                            <span className='block truncate'>{item.name}</span>
+                            {(selected || item.value === text2speech?.voice) && (
+                              <span className='absolute inset-y-0 right-0 flex items-center pr-4'>
+                                <CheckIcon className='h-5 w-5' aria-hidden='true' />
+                              </span>
+                            )}
+                          </>
+                        )}
+                      </Listbox.Option>
+                    ))}
+                  </Listbox.Options>
+                </Transition>
+              </div>
+            </Listbox>
+          </div>
+          <div>
+            <div className='mb-2 text-[13px] font-semibold text-gray-800'>{t('appDebug.voice.voiceSettings.autoPlay')}</div>
+            <RadioGroup
+              className='space-x-3'
+              options={[
+                { label: t('appDebug.voice.voiceSettings.enabled'), value: TtsAutoPlay.enabled },
+                { label: t('appDebug.voice.voiceSettings.disabled'), value: TtsAutoPlay.disabled },
+              ]}
+              value={text2speech?.autoPlay || TtsAutoPlay.disabled}
+              onChange={(value: TtsAutoPlay) => {
+                handleChange({
+                  autoPlay: value,
+                })
+              }}
+            />
+          </div>
+        </div>
+      </div>
+ + ) +} + +export default React.memo(VoiceParamConfig) diff --git a/web/app/components/base/features/feature-panel/text-to-speech/params-config.tsx b/web/app/components/base/features/feature-panel/text-to-speech/params-config.tsx new file mode 100644 index 0000000000..095fd6cce8 --- /dev/null +++ b/web/app/components/base/features/feature-panel/text-to-speech/params-config.tsx @@ -0,0 +1,48 @@ +'use client' +import { memo, useState } from 'react' +import { useTranslation } from 'react-i18next' +import type { OnFeaturesChange } from '../../types' +import ParamConfigContent from './param-config-content' +import cn from '@/utils/classnames' +import { Settings01 } from '@/app/components/base/icons/src/vender/line/general' +import { + PortalToFollowElem, + PortalToFollowElemContent, + PortalToFollowElemTrigger, +} from '@/app/components/base/portal-to-follow-elem' + +type ParamsConfigProps = { + onChange?: OnFeaturesChange + disabled?: boolean +} +const ParamsConfig = ({ + onChange, + disabled, +}: ParamsConfigProps) => { + const { t } = useTranslation() + const [open, setOpen] = useState(false) + + return ( + + !disabled && setOpen(v => !v)}> +
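+        {/* Settings trigger: a click toggles the voice-settings popover unless the feature is disabled */}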
+        <div className={cn(
+          'flex items-center rounded-md h-7 px-3 space-x-1 text-xs font-medium text-gray-700 cursor-pointer hover:bg-gray-200',
+          open && 'bg-gray-200',
+          disabled && 'cursor-not-allowed opacity-50',
+        )}>
+          <Settings01 className='w-3.5 h-3.5' />
+          <div>{t('appDebug.voice.settings')}</div>
+        </div>
+      </PortalToFollowElemTrigger>
+      <PortalToFollowElemContent style={{ zIndex: 50 }}>
+        <div className='w-80 sm:w-[412px] p-4 bg-white rounded-lg border-[0.5px] border-gray-200 shadow-lg space-y-3'>
+          <ParamConfigContent onChange={onChange} />
+        </div>
+      </PortalToFollowElemContent>
+    </PortalToFollowElem>
+ ) +} +export default memo(ParamsConfig) diff --git a/web/app/signin/forms.tsx b/web/app/signin/forms.tsx new file mode 100644 index 0000000000..70a34c26fa --- /dev/null +++ b/web/app/signin/forms.tsx @@ -0,0 +1,34 @@ +'use client' +import React from 'react' +import { useSearchParams } from 'next/navigation' + +import NormalForm from './normalForm' +import OneMoreStep from './oneMoreStep' +import cn from '@/utils/classnames' + +const Forms = () => { + const searchParams = useSearchParams() + const step = searchParams.get('step') + + const getForm = () => { + switch (step) { + case 'next': + return + default: + return + } + } + return
+    <div className={cn(
+      'flex flex-col items-center w-full grow justify-center',
+      'px-6',
+      'md:px-[108px]',
+    )}>
+      <div className='flex flex-col md:w-[400px]'>
+        {getForm()}
+      </div>
+    </div>
+} + +export default Forms diff --git a/web/app/signin/userSSOForm.tsx b/web/app/signin/userSSOForm.tsx new file mode 100644 index 0000000000..f01afa9eaf --- /dev/null +++ b/web/app/signin/userSSOForm.tsx @@ -0,0 +1,107 @@ +'use client' +import { useRouter, useSearchParams } from 'next/navigation' +import type { FC } from 'react' +import { useEffect, useState } from 'react' +import { useTranslation } from 'react-i18next' +import cn from '@/utils/classnames' +import Toast from '@/app/components/base/toast' +import { getUserOAuth2SSOUrl, getUserOIDCSSOUrl, getUserSAMLSSOUrl } from '@/service/sso' +import Button from '@/app/components/base/button' +import useRefreshToken from '@/hooks/use-refresh-token' + +type UserSSOFormProps = { + protocol: string +} + +const UserSSOForm: FC = ({ + protocol, +}) => { + const { getNewAccessToken } = useRefreshToken() + const searchParams = useSearchParams() + const consoleToken = searchParams.get('access_token') + const refreshToken = searchParams.get('refresh_token') + const message = searchParams.get('message') + + const router = useRouter() + const { t } = useTranslation() + + const [isLoading, setIsLoading] = useState(false) + + useEffect(() => { + if (refreshToken && consoleToken) { + localStorage.setItem('console_token', consoleToken) + localStorage.setItem('refresh_token', refreshToken) + getNewAccessToken() + router.replace('/apps') + } + + if (message) { + Toast.notify({ + type: 'error', + message, + }) + } + }, [consoleToken, refreshToken, message, router]) + + const handleSSOLogin = () => { + setIsLoading(true) + if (protocol === 'saml') { + getUserSAMLSSOUrl().then((res) => { + router.push(res.url) + }).finally(() => { + setIsLoading(false) + }) + } + else if (protocol === 'oidc') { + getUserOIDCSSOUrl().then((res) => { + document.cookie = `user-oidc-state=${res.state}` + router.push(res.url) + }).finally(() => { + setIsLoading(false) + }) + } + else if (protocol === 'oauth2') { + getUserOAuth2SSOUrl().then((res) => { + document.cookie = `user-oauth2-state=${res.state}` + router.push(res.url) + }).finally(() => { + setIsLoading(false) + }) + } + else { + Toast.notify({ + type: 'error', + message: 'invalid SSO protocol', + }) + setIsLoading(false) + } + } + + return ( +
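+    // Single-action SSO screen: the button below starts the SAML / OIDC / OAuth2 redirect
+    // selected by the `protocol` prop via handleSSOLogin.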
+    <div className={cn(
+      'flex w-full min-h-screen',
+      'sm:p-4 lg:p-8',
+      'gap-x-20',
+      'justify-center lg:justify-start',
+    )}>
+      <div className='flex w-full flex-col bg-white shadow rounded-2xl shrink-0'>
+        <div className={cn(
+          'flex flex-col items-center w-full grow justify-center',
+          'px-6',
+          'md:px-[108px]',
+        )}>
+          <div className='flex flex-col md:w-[400px]'>
+            <div className='w-full mx-auto'>
+              <h2 className='text-[32px] font-bold text-gray-900'>{t('login.pageTitle')}</h2>
+            </div>
+            <div className='w-full mx-auto mt-10'>
+              <Button
+                variant='primary'
+                disabled={isLoading}
+                className='w-full'
+                onClick={() => { handleSSOLogin() }}
+              >
+                {t('login.sso')}
+              </Button>
+            </div>
+          </div>
+        </div>
+      </div>
+    </div>
+ ) +} + +export default UserSSOForm
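For context, the diff adds UserSSOForm and Forms but not the page that decides which of them to render. The sketch below is only an illustration of how such a wrapper might look; the SignInContent component and its ssoProtocol prop are assumptions and are not part of this change.

'use client'
import React from 'react'
import Forms from './forms'
import UserSSOForm from './userSSOForm'

// Hypothetical wrapper (not part of this diff): render the SSO-only screen
// when a protocol is enforced, otherwise fall back to the normal form flow.
type SignInContentProps = {
  ssoProtocol?: 'saml' | 'oidc' | 'oauth2'
}

const SignInContent = ({ ssoProtocol }: SignInContentProps) => {
  if (ssoProtocol)
    return <UserSSOForm protocol={ssoProtocol} />
  return <Forms />
}

export default SignInContent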