diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py
index 8518d34a8e..4046417076 100644
--- a/api/controllers/console/app/generator.py
+++ b/api/controllers/console/app/generator.py
@@ -85,5 +85,35 @@ class RuleCodeGenerateApi(Resource):
         return code_result
 
 
+class RuleStructuredOutputGenerateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
+        args = parser.parse_args()
+
+        account = current_user
+        try:
+            structured_output = LLMGenerator.generate_structured_output(
+                tenant_id=account.current_tenant_id,
+                instruction=args["instruction"],
+                model_config=args["model_config"],
+            )
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except InvokeError as e:
+            raise CompletionRequestError(e.description)
+
+        return structured_output
+
+
 api.add_resource(RuleGenerateApi, "/rule-generate")
 api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")
+api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate")
diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py
index dc0009f36e..d4a33645ab 100644
--- a/api/controllers/console/auth/forgot_password.py
+++ b/api/controllers/console/auth/forgot_password.py
@@ -16,7 +16,7 @@ from controllers.console.auth.error import (
     PasswordMismatchError,
 )
 from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError
-from controllers.console.wraps import setup_required
+from controllers.console.wraps import email_password_login_enabled, setup_required
 from events.tenant_event import tenant_was_created
 from extensions.ext_database import db
 from libs.helper import email, extract_remote_ip
@@ -30,6 +30,7 @@ from services.feature_service import FeatureService
 class ForgotPasswordSendEmailApi(Resource):
     @setup_required
+    @email_password_login_enabled
     def post(self):
         parser = reqparse.RequestParser()
         parser.add_argument("email", type=email, required=True, location="json")
@@ -62,6 +63,7 @@ class ForgotPasswordSendEmailApi(Resource):
 class ForgotPasswordCheckApi(Resource):
     @setup_required
+    @email_password_login_enabled
     def post(self):
         parser = reqparse.RequestParser()
         parser.add_argument("email", type=str, required=True, location="json")
@@ -86,12 +88,24 @@ class ForgotPasswordCheckApi(Resource):
             AccountService.add_forgot_password_error_rate_limit(args["email"])
             raise EmailCodeError()
 
+        # The code checked out; revoke the first-phase token so it cannot be reused
+        AccountService.revoke_reset_password_token(args["token"])
+
+        # Issue a fresh token that is only valid for the reset phase
+        _, new_token = AccountService.generate_reset_password_token(
+            user_email, code=args["code"], additional_data={"phase": "reset"}
+        )
+
         AccountService.reset_forgot_password_error_rate_limit(args["email"])
-        return {"is_valid": True, "email": token_data.get("email")}
+        return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
 
 
 class ForgotPasswordResetApi(Resource):
     @setup_required
+    @email_password_login_enabled
     def post(self):
         parser = reqparse.RequestParser()
         parser.add_argument("token", type=str, required=True, nullable=False, location="json")
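+        # Note: the token parsed above must be the second-phase token minted by
+        # ForgotPasswordCheckApi with {"phase": "reset"}; the first-phase email
+        # token is revoked as soon as the verification code has been checked.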
@@ -107,6 +121,9 @@ class ForgotPasswordResetApi(Resource):
         reset_data = AccountService.get_reset_password_data(args["token"])
         if not reset_data:
             raise InvalidTokenError()
+        # Only accept tokens minted for the reset phase by ForgotPasswordCheckApi
+        if reset_data.get("phase", "") != "reset":
+            raise InvalidTokenError()
 
         # Revoke token to prevent reuse
         AccountService.revoke_reset_password_token(args["token"])
diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py
index 41362e9fa2..16c1dcc441 100644
--- a/api/controllers/console/auth/login.py
+++ b/api/controllers/console/auth/login.py
@@ -22,7 +22,7 @@ from controllers.console.error import (
     EmailSendIpLimitError,
     NotAllowedCreateWorkspace,
 )
-from controllers.console.wraps import setup_required
+from controllers.console.wraps import email_password_login_enabled, setup_required
 from events.tenant_event import tenant_was_created
 from libs.helper import email, extract_remote_ip
 from libs.password import valid_password
@@ -38,6 +38,7 @@ class LoginApi(Resource):
     """Resource for user login."""
 
     @setup_required
+    @email_password_login_enabled
     def post(self):
         """Authenticate user and login."""
         parser = reqparse.RequestParser()
@@ -110,6 +111,7 @@ class LogoutApi(Resource):
 class ResetPasswordSendEmailApi(Resource):
     @setup_required
+    @email_password_login_enabled
     def post(self):
         parser = reqparse.RequestParser()
         parser.add_argument("email", type=email, required=True, location="json")
diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py
index 6caaae87f4..e5e8038ad7 100644
--- a/api/controllers/console/wraps.py
+++ b/api/controllers/console/wraps.py
@@ -210,3 +210,16 @@ def enterprise_license_required(view):
             return view(*args, **kwargs)
 
     return decorated
+
+
+def email_password_login_enabled(view):
+    @wraps(view)
+    def decorated(*args, **kwargs):
+        features = FeatureService.get_system_features()
+        if features.enable_email_password_login:
+            return view(*args, **kwargs)
+
+        # otherwise, return 403
+        abort(403)
+
+    return decorated
diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py
index 75687f9ae3..d5d2ca60fa 100644
--- a/api/core/llm_generator/llm_generator.py
+++ b/api/core/llm_generator/llm_generator.py
@@ -10,6 +10,7 @@ from core.llm_generator.prompts import (
     GENERATOR_QA_PROMPT,
     JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
     PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
+    SYSTEM_STRUCTURED_OUTPUT_GENERATE,
     WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
 )
 from core.model_manager import ModelManager
@@ -340,3 +341,40 @@ class LLMGenerator:
 
         answer = cast(str, response.message.content)
         return answer.strip()
+
+    @classmethod
+    def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict):
+        model_manager = ModelManager()
+        model_instance = model_manager.get_model_instance(
+            tenant_id=tenant_id,
+            model_type=ModelType.LLM,
+            provider=model_config.get("provider", ""),
+            model=model_config.get("name", ""),
+        )
+
+        prompt_messages = [
+            SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE),
+            UserPromptMessage(content=instruction),
+        ]
+        model_parameters = model_config.get("model_parameters", {})
+
+        try:
+            response = cast(
+                LLMResult,
+                model_instance.invoke_llm(
+                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
+                ),
+            )
+
+            generated_json_schema = cast(str, response.message.content)
+            return {"output": generated_json_schema, "error": ""}
+
+        except InvokeError as e:
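+            # Provider-side failures are surfaced in the returned payload rather
+            # than raised, e.g. (illustrative shape):
+            #   {"output": "", "error": "Failed to generate JSON Schema. Error: ..."}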
+            error = str(e)
+            return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
+        except Exception as e:
+            logging.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
+            return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py
index cf20e60c82..82d22d7f89 100644
--- a/api/core/llm_generator/prompts.py
+++ b/api/core/llm_generator/prompts.py
@@ -220,3 +220,104 @@ Here is the task description:
 {{INPUT_TEXT}}
 You just need to generate the output
 """  # noqa: E501
+
+SYSTEM_STRUCTURED_OUTPUT_GENERATE = """
+Your task is to convert simple user descriptions into properly formatted JSON Schema definitions. When a user describes data fields they need, generate a complete, valid JSON Schema that accurately represents those fields with appropriate types and requirements.
+
+## Instructions:
+
+1. Analyze the user's description of their data needs
+2. Identify each property that should be included in the schema
+3. Determine the appropriate data type for each property
+4. Decide which properties should be required
+5. Generate a complete JSON Schema with proper syntax
+6. Include appropriate constraints when specified (min/max values, patterns, formats)
+7. Provide ONLY the JSON Schema without any additional explanations, comments, or markdown formatting.
+8. DO NOT use markdown code blocks (``` or ``` json). Return the raw JSON Schema directly.
+
+## Examples:
+
+### Example 1:
+**User Input:** I need name and age
+**JSON Schema Output:**
+{
+  "type": "object",
+  "properties": {
+    "name": { "type": "string" },
+    "age": { "type": "number" }
+  },
+  "required": ["name", "age"]
+}
+
+### Example 2:
+**User Input:** I want to store information about books including title, author, publication year and optional page count
+**JSON Schema Output:**
+{
+  "type": "object",
+  "properties": {
+    "title": { "type": "string" },
+    "author": { "type": "string" },
+    "publicationYear": { "type": "integer" },
+    "pageCount": { "type": "integer" }
+  },
+  "required": ["title", "author", "publicationYear"]
+}
+
+### Example 3:
+**User Input:** Create a schema for user profiles with email, password, and age (must be at least 18)
+**JSON Schema Output:**
+{
+  "type": "object",
+  "properties": {
+    "email": {
+      "type": "string",
+      "format": "email"
+    },
+    "password": {
+      "type": "string",
+      "minLength": 8
+    },
+    "age": {
+      "type": "integer",
+      "minimum": 18
+    }
+  },
+  "required": ["email", "password", "age"]
+}
+
+### Example 4:
+**User Input:** I need an album schema, the album has songs, and each song has name, duration, and artist.
+**JSON Schema Output:**
+{
+  "type": "object",
+  "properties": {
+    "songs": {
+      "type": "array",
+      "items": {
+        "type": "object",
+        "properties": {
+          "name": {
+            "type": "string"
+          },
+          "duration": {
+            "type": "string"
+          },
+          "artist": {
+            "type": "string"
+          }
+        },
+        "required": [
+          "name",
+          "duration",
+          "artist"
+        ]
+      }
+    }
+  },
+  "required": [
+    "songs"
+  ]
+}
+
+Now, generate a JSON Schema based on my description
+"""  # noqa: E501
diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py
index 977678b893..3bed2460dd 100644
--- a/api/core/model_runtime/entities/message_entities.py
+++ b/api/core/model_runtime/entities/message_entities.py
@@ -1,8 +1,8 @@
 from collections.abc import Sequence
 from enum import Enum, StrEnum
-from typing import Optional
+from typing import Any, Optional, Union
 
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, field_serializer, field_validator
 
 
 class PromptMessageRole(Enum):
@@ -135,6 +135,16 @@ class PromptMessage(BaseModel):
         """
         return not self.content
 
+    @field_serializer("content")
+    def serialize_content(
+        self, content: Optional[Union[str, Sequence[PromptMessageContent]]]
+    ) -> Optional[str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent]]:
+        if content is None or isinstance(content, str):
+            return content
+        if isinstance(content, list):
+            return [item.model_dump() if hasattr(item, "model_dump") else item for item in content]
+        return content
+
 
 class UserPromptMessage(PromptMessage):
     """
diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py
index 3225f03fbd..373ef2bbe2 100644
--- a/api/core/model_runtime/entities/model_entities.py
+++ b/api/core/model_runtime/entities/model_entities.py
@@ -2,7 +2,7 @@ from decimal import Decimal
 from enum import Enum, StrEnum
 from typing import Any, Optional
 
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, model_validator
 
 from core.model_runtime.entities.common_entities import I18nObject
 
@@ -85,6 +85,7 @@ class ModelFeature(Enum):
     DOCUMENT = "document"
     VIDEO = "video"
     AUDIO = "audio"
+    STRUCTURED_OUTPUT = "structured-output"
 
 
 class DefaultParameterName(StrEnum):
@@ -197,6 +198,19 @@ class AIModelEntity(ProviderModel):
     parameter_rules: list[ParameterRule] = []
     pricing: Optional[PriceConfig] = None
 
+    @model_validator(mode="after")
+    def validate_model(self):
+        supported_schema_keys = ["json_schema"]
+        schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None)
+        if not schema_key:
+            return self
+        if self.features is None:
+            self.features = [ModelFeature.STRUCTURED_OUTPUT]
+        else:
+            if ModelFeature.STRUCTURED_OUTPUT not in self.features:
+                self.features.append(ModelFeature.STRUCTURED_OUTPUT)
+        return self
+
 
 class ModelUsage(BaseModel):
     pass
diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py
index f402da030f..db07e52f3f 100644
--- a/api/core/plugin/backwards_invocation/node.py
+++ b/api/core/plugin/backwards_invocation/node.py
@@ -39,6 +39,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
         :param query: str
         :return: dict
         """
+        # FIXME(-LAN-): Avoid importing the service layer into core
         workflow_service = WorkflowService()
         node_id = "1919810"
         node_data = ParameterExtractorNodeData(
@@ -89,6 +90,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
         :param query: str
         :return: dict
         """
+        # FIXME(-LAN-): Avoid importing the service layer into core
         workflow_service = WorkflowService()
         node_id = "1919810"
         node_data = QuestionClassifierNodeData(
diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py
index 4af2578197..63695e6f3f 100644
--- a/api/core/rag/datasource/vdb/oracle/oraclevector.py
+++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py
@@ -2,12 +2,12 @@ import array
 import json
 import re
 import uuid
-from contextlib import contextmanager
 from typing import Any
 
 import jieba.posseg as pseg  # type: ignore
 import numpy
 import oracledb
+from oracledb.connection import Connection
 from pydantic import BaseModel, model_validator
 
 from configs import dify_config
@@ -70,6 +70,7 @@ class OracleVector(BaseVector):
         super().__init__(collection_name)
         self.pool = self._create_connection_pool(config)
         self.table_name = f"embedding_{collection_name}"
+        self.config = config
 
     def get_type(self) -> str:
         return VectorType.ORACLE
@@ -107,16 +108,19 @@ class OracleVector(BaseVector):
             outconverter=self.numpy_converter_out,
         )
 
+    def _get_connection(self) -> Connection:
+        connection = oracledb.connect(user=self.config.user, password=self.config.password, dsn=self.config.dsn)
+        return connection
+
     def _create_connection_pool(self, config: OracleVectorConfig):
         pool_params = {
            "user": config.user,
             "password": config.password,
             "dsn": config.dsn,
             "min": 1,
-            "max": 50,
+            "max": 5,
             "increment": 1,
         }
-
         if config.is_autonomous:
             pool_params.update(
                 {
                     "config_dir": config.config_dir,
                     "wallet_location": config.wallet_location,
                     "wallet_password": config.wallet_password,
                 }
             )
-
         return oracledb.create_pool(**pool_params)
 
-    @contextmanager
-    def _get_cursor(self):
-        conn = self.pool.acquire()
-        conn.inputtypehandler = self.input_type_handler
-        conn.outputtypehandler = self.output_type_handler
-        cur = conn.cursor()
-        try:
-            yield cur
-        finally:
-            cur.close()
-            conn.commit()
-            conn.close()
-
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
         dimension = len(embeddings[0])
         self._create_collection(dimension)
@@ -162,41 +152,60 @@ class OracleVector(BaseVector):
                     numpy.array(embeddings[i]),
                 )
             )
-        # print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
-        with self._get_cursor() as cur:
-            cur.executemany(
-                f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values
-            )
+        with self._get_connection() as conn:
+            conn.inputtypehandler = self.input_type_handler
+            conn.outputtypehandler = self.output_type_handler
+            for value in values:
+                with conn.cursor() as cur:
+                    try:
+                        cur.execute(
+                            f"""INSERT INTO {self.table_name} (id, text, meta, embedding)
+                            VALUES (:1, :2, :3, :4)""",
+                            value,
+                        )
+                        conn.commit()
+                    except Exception:
+                        # don't silently drop documents; undo the failed row and propagate
+                        conn.rollback()
+                        raise
        return pks
 
     def text_exists(self, id: str) -> bool:
-        with self._get_cursor() as cur:
-            cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,))
-            return cur.fetchone() is not None
+        with self._get_connection() as conn:
+            with conn.cursor() as cur:
+                cur.execute(f"SELECT id FROM {self.table_name} WHERE id = :1", [id])
+                return cur.fetchone() is not None
 
     def get_by_ids(self, ids: list[str]) -> list[Document]:
-        with self._get_cursor() as cur:
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) - docs = [] - for record in cur: - docs.append(Document(page_content=record[1], metadata=record[0])) + with self._get_connection() as conn: + with conn.cursor() as cur: + cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) + docs = [] + for record in cur: + docs.append(Document(page_content=record[1], metadata=record[0])) + self.pool.release(connection=conn) + conn.close() return docs def delete_by_ids(self, ids: list[str]) -> None: if not ids: return - with self._get_cursor() as cur: - cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),)) + with self._get_connection() as conn: + with conn.cursor() as cur: + cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),)) + conn.commit() + conn.close() def delete_by_metadata_field(self, key: str, value: str) -> None: - with self._get_cursor() as cur: - cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value)) + with self._get_connection() as conn: + with conn.cursor() as cur: + cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value)) + conn.commit() + conn.close() def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: """ Search the nearest neighbors to a vector. :param query_vector: The input vector to search for similar items. + :param top_k: The number of nearest neighbors to return, default is 5. :return: List of Documents that are nearest to the query vector. """ top_k = kwargs.get("top_k", 4) @@ -205,20 +222,25 @@ class OracleVector(BaseVector): if document_ids_filter: document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) where_clause = f"WHERE metadata->>'document_id' in ({document_ids})" - with self._get_cursor() as cur: - cur.execute( - f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}" - f" {where_clause} ORDER BY distance fetch first {top_k} rows only", - [numpy.array(query_vector)], - ) - docs = [] - score_threshold = float(kwargs.get("score_threshold") or 0.0) - for record in cur: - metadata, text, distance = record - score = 1 - distance - metadata["score"] = score - if score > score_threshold: - docs.append(Document(page_content=text, metadata=metadata)) + with self._get_connection() as conn: + conn.inputtypehandler = self.input_type_handler + conn.outputtypehandler = self.output_type_handler + with conn.cursor() as cur: + cur.execute( + f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine) + AS distance FROM {self.table_name} + {where_clause} ORDER BY distance fetch first {top_k} rows only""", + [numpy.array(query_vector)], + ) + docs = [] + score_threshold = float(kwargs.get("score_threshold") or 0.0) + for record in cur: + metadata, text, distance = record + score = 1 - distance + metadata["score"] = score + if score > score_threshold: + docs.append(Document(page_content=text, metadata=metadata)) + conn.close() return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -228,7 +250,7 @@ class OracleVector(BaseVector): top_k = kwargs.get("top_k", 5) # just not implement fetch by score_threshold now, may be later - # score_threshold = float(kwargs.get("score_threshold") or 0.0) + score_threshold = float(kwargs.get("score_threshold") or 0.0) if len(query) > 0: # Check which language the query is in zh_pattern = re.compile("[\u4e00-\u9fa5]+") @@ -239,7 +261,7 @@ class 
         words = pseg.cut(query)
         current_entity = ""
         for word, pos in words:
-            if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}:  # nr: 人名,ns: 地名,nt: 机构名
+            if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}:  # nr: person name, ns: place name, nt: organization name
                 current_entity += word
             else:
                 if current_entity:
@@ -260,30 +273,33 @@ class OracleVector(BaseVector):
         for token in all_tokens:
             if token not in stop_words:
                 entities.append(token)
-        with self._get_cursor() as cur:
-            document_ids_filter = kwargs.get("document_ids_filter")
-            where_clause = ""
-            if document_ids_filter:
-                document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
-                where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
-            cur.execute(
-                f"select meta, text, embedding FROM {self.table_name}"
-                f"WHERE CONTAINS(text, :1, 1) > 0 {where_clause} "
-                f"order by score(1) desc fetch first {top_k} rows only",
-                [" ACCUM ".join(entities)],
-            )
-            docs = []
-            for record in cur:
-                metadata, text, embedding = record
-                docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
+        with self._get_connection() as conn:
+            with conn.cursor() as cur:
+                document_ids_filter = kwargs.get("document_ids_filter")
+                where_clause = ""
+                if document_ids_filter:
+                    document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+                    where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
+                cur.execute(
+                    f"""select meta, text, embedding FROM {self.table_name}
+                    WHERE CONTAINS(text, :kk, 1) > 0 {where_clause}
+                    order by score(1) desc fetch first {top_k} rows only""",
+                    kk=" ACCUM ".join(entities),
+                )
+                docs = []
+                for record in cur:
+                    metadata, text, embedding = record
+                    docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
             return docs
         else:
             return [Document(page_content="", metadata={})]
         return []
 
     def delete(self) -> None:
-        with self._get_cursor() as cur:
-            cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints")
+        with self._get_connection() as conn:
+            with conn.cursor() as cur:
+                cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints")
+            conn.commit()
 
     def _create_collection(self, dimension: int):
         cache_key = f"vector_indexing_{self._collection_name}"
@@ -293,11 +309,13 @@ class OracleVector(BaseVector):
         if redis_client.get(collection_exist_cache_key):
             return
 
-        with self._get_cursor() as cur:
-            cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
-            redis_client.set(collection_exist_cache_key, 1, ex=3600)
-        with self._get_cursor() as cur:
-            cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
+        with self._get_connection() as conn:
+            with conn.cursor() as cur:
+                cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
+                redis_client.set(collection_exist_cache_key, 1, ex=3600)
+            with conn.cursor() as cur:
+                cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
+            conn.commit()
 
 
 class OracleVectorFactory(AbstractVectorFactory):
diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py
index 70c618a631..edaa8c92fa 100644
--- a/api/core/rag/extractor/word_extractor.py
+++ b/api/core/rag/extractor/word_extractor.py
@@ -126,9 +126,7 @@ class WordExtractor(BaseExtractor):
                 db.session.add(upload_file)
                 db.session.commit()
-                image_map[rel.target_part] = (
-                    f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/file-preview)"
-                )
+                image_map[rel.target_part] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)"
f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)" return image_map diff --git a/api/core/repository/workflow_node_execution_repository.py b/api/core/repository/workflow_node_execution_repository.py index 6dea4566de..9bb790cb0f 100644 --- a/api/core/repository/workflow_node_execution_repository.py +++ b/api/core/repository/workflow_node_execution_repository.py @@ -86,3 +86,12 @@ class WorkflowNodeExecutionRepository(Protocol): execution: The WorkflowNodeExecution instance to update """ ... + + def clear(self) -> None: + """ + Clear all WorkflowNodeExecution records based on implementation-specific criteria. + + This method is intended to be used for bulk deletion operations, such as removing + all records associated with a specific app_id and tenant_id in multi-tenant implementations. + """ + ... diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index f661294ec4..f5838c3b76 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -94,7 +94,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): "title": item.metadata.get("title"), "content": item.page_content, } - context_list.append(source) + context_list.append(source) for hit_callback in self.hit_callbacks: hit_callback.return_retriever_resource_info(context_list) diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 7c8960fe49..da40cbcdea 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -16,7 +16,7 @@ from core.variables.segments import StringSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.agent.entities import AgentNodeData, ParamsAutoGenerated +from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event.event import RunCompletedEvent @@ -251,7 +251,12 @@ class AgentNode(ToolNode): prompt_message.model_dump(mode="json") for prompt_message in prompt_messages ] value["history_prompt_messages"] = history_prompt_messages - value["entity"] = model_schema.model_dump(mode="json") if model_schema else None + if model_schema: + # remove structured output feature to support old version agent plugin + model_schema = self._remove_unsupported_model_features_for_old_version(model_schema) + value["entity"] = model_schema.model_dump(mode="json") + else: + value["entity"] = None result[parameter_name] = value return result @@ -348,3 +353,10 @@ class AgentNode(ToolNode): ) model_schema = model_type_instance.get_model_schema(model_name, model_credentials) return model_instance, model_schema + + def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity: + if model_schema.features: + for feature in model_schema.features: + if feature.value not in AgentOldVersionModelFeatures: + model_schema.features.remove(feature) + return model_schema diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index 87cc7e9824..77e94375bf 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ 
b/api/core/workflow/nodes/agent/entities.py
@@ -24,3 +24,18 @@ class AgentNodeData(BaseNodeData):
 class ParamsAutoGenerated(Enum):
     CLOSE = 0
     OPEN = 1
+
+
+class AgentOldVersionModelFeatures(Enum):
+    """
+    LLM features supported by old-SDK-version agent plugins.
+    """
+
+    TOOL_CALL = "tool-call"
+    MULTI_TOOL_CALL = "multi-tool-call"
+    AGENT_THOUGHT = "agent-thought"
+    VISION = "vision"
+    STREAM_TOOL_CALL = "stream-tool-call"
+    DOCUMENT = "document"
+    VIDEO = "video"
+    AUDIO = "audio"
diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py
index bf54fdb80c..486b4b01af 100644
--- a/api/core/workflow/nodes/llm/entities.py
+++ b/api/core/workflow/nodes/llm/entities.py
@@ -65,6 +65,8 @@ class LLMNodeData(BaseNodeData):
     memory: Optional[MemoryConfig] = None
     context: ContextConfig
     vision: VisionConfig = Field(default_factory=VisionConfig)
+    structured_output: dict | None = None
+    structured_output_enabled: bool = False
 
     @field_validator("prompt_config", mode="before")
     @classmethod
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index fe0ed3e564..8db7394e54 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -4,6 +4,8 @@ from collections.abc import Generator, Mapping, Sequence
 from datetime import UTC, datetime
 from typing import TYPE_CHECKING, Any, Optional, cast
 
+import json_repair
+
 from configs import dify_config
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.model_entities import ModelStatus
@@ -27,7 +29,13 @@ from core.model_runtime.entities.message_entities import (
     SystemPromptMessage,
     UserPromptMessage,
 )
-from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    ModelFeature,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+)
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.plugin.entities.plugin import ModelProviderID
@@ -57,6 +65,12 @@ from core.workflow.nodes.event import (
     RunRetrieverResourceEvent,
     RunStreamChunkEvent,
 )
+from core.workflow.utils.structured_output.entities import (
+    ResponseFormat,
+    SpecialModelType,
+    SupportStructuredOutputStatus,
+)
+from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from extensions.ext_database import db
 from models.model import Conversation
@@ -92,6 +106,12 @@ class LLMNode(BaseNode[LLMNodeData]):
     _node_type = NodeType.LLM
 
     def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
+        def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]:
+            """Process structured output if enabled"""
+            if not self.node_data.structured_output_enabled or not self.node_data.structured_output:
+                return None
+            return self._parse_structured_output(text)
+
         node_inputs: Optional[dict[str, Any]] = None
         process_data = None
         result_text = ""
@@ -130,7 +150,6 @@ class LLMNode(BaseNode[LLMNodeData]):
             if isinstance(event, RunRetrieverResourceEvent):
                 context = event.context
                 yield event
-
             if context:
                 node_inputs["#context#"] = context
@@ -192,7 +211,9 @@ class LLMNode(BaseNode[LLMNodeData]):
                     self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
                 break
         outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
-
+        structured_output = process_structured_output(result_text)
+        if structured_output:
+            outputs["structured_output"] = structured_output
         yield RunCompletedEvent(
             run_result=NodeRunResult(
                 status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -513,7 +534,12 @@ class LLMNode(BaseNode[LLMNodeData]):
 
         if not model_schema:
             raise ModelNotExistError(f"Model {model_name} not exist.")
-
+        support_structured_output = self._check_model_structured_output_support()
+        if support_structured_output == SupportStructuredOutputStatus.SUPPORTED:
+            completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
+        elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
+            # Set appropriate response format based on model capabilities
+            self._set_response_format(completion_params, model_schema.parameter_rules)
         return model_instance, ModelConfigWithCredentialsEntity(
             provider=provider_name,
             model=model_name,
@@ -724,10 +750,29 @@ class LLMNode(BaseNode[LLMNodeData]):
                 "No prompt found in the LLM configuration. "
                 "Please ensure a prompt is properly configured before proceeding."
             )
-
+        support_structured_output = self._check_model_structured_output_support()
+        if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
+            filtered_prompt_messages = self._handle_prompt_based_schema(
+                prompt_messages=filtered_prompt_messages,
+            )
         stop = model_config.stop
         return filtered_prompt_messages, stop
 
+    def _parse_structured_output(self, result_text: str) -> dict[str, Any] | list[Any]:
+        structured_output: dict[str, Any] | list[Any] = {}
+        try:
+            parsed = json.loads(result_text)
+            if not isinstance(parsed, (dict | list)):
+                raise LLMNodeError(f"Failed to parse structured output: {result_text}")
+            structured_output = parsed
+        except json.JSONDecodeError:
+            # if the result_text is not valid JSON, try to repair it
+            parsed = json_repair.loads(result_text)
+            if not isinstance(parsed, (dict | list)):
+                raise LLMNodeError(f"Failed to parse structured output: {result_text}")
+            structured_output = parsed
+        return structured_output
+
     @classmethod
     def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
         provider_model_bundle = model_instance.provider_model_bundle
@@ -926,6 +971,171 @@ class LLMNode(BaseNode[LLMNodeData]):
 
         return prompt_messages
 
+    def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
+        """
+        Handle structured output for models with native JSON schema support.
+
+        :param model_parameters: Model parameters to update
+        :param rules: Model parameter rules
+        :return: Updated model parameters with JSON schema configuration
+        """
+        # Process schema according to model requirements
+        schema = self._fetch_structured_output_schema()
+        schema_json = self._prepare_schema_for_model(schema)
+
+        # Set JSON schema in parameters
+        model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
+
+        # Set appropriate response format if required by the model
+        for rule in rules:
+            if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
+                model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
+
+        return model_parameters
+
+    def _handle_prompt_based_schema(self, prompt_messages: Sequence[PromptMessage]) -> list[PromptMessage]:
+        """
+        Handle structured output for models without native JSON schema support.
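+
+        For example (illustrative), a schema like {"type": "object", ...} is rendered
+        into STRUCTURED_OUTPUT_PROMPT's {{schema}} placeholder and prepended below.
+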
+        This function modifies the prompt messages to include schema-based output requirements.
+
+        Args:
+            prompt_messages: Original sequence of prompt messages
+
+        Returns:
+            list[PromptMessage]: Updated prompt messages with structured output requirements
+        """
+        # Convert schema to string format
+        schema_str = json.dumps(self._fetch_structured_output_schema(), ensure_ascii=False)
+
+        # Find the existing system prompt, if any
+        system_prompt = next(
+            (prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)),
+            None,
+        )
+        structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)
+        # Prepare system prompt content
+        system_prompt_content = (
+            structured_output_prompt + "\n\n" + system_prompt.content
+            if system_prompt and isinstance(system_prompt.content, str)
+            else structured_output_prompt
+        )
+        system_prompt = SystemPromptMessage(content=system_prompt_content)
+
+        filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)]
+        updated_prompt = [system_prompt] + filtered_prompts
+
+        return updated_prompt
+
+    def _set_response_format(self, model_parameters: dict, rules: list) -> None:
+        """
+        Set the appropriate response format parameter based on model rules.
+
+        :param model_parameters: Model parameters to update
+        :param rules: Model parameter rules
+        """
+        for rule in rules:
+            if rule.name == "response_format":
+                if ResponseFormat.JSON.value in rule.options:
+                    model_parameters["response_format"] = ResponseFormat.JSON.value
+                elif ResponseFormat.JSON_OBJECT.value in rule.options:
+                    model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
+
+    def _prepare_schema_for_model(self, schema: dict) -> dict:
+        """
+        Prepare JSON schema based on model requirements.
+
+        Different models have different requirements for JSON schema formatting.
+        This function handles these differences.
+
+        :param schema: The original JSON schema
+        :return: Processed schema compatible with the current model
+        """
+        # A shallow copy suffices: _fetch_structured_output_schema returns a fresh object
+        processed_schema = schema.copy()
+
+        # Convert boolean types to string types (common requirement)
+        convert_boolean_to_string(processed_schema)
+
+        # Apply model-specific transformations
+        if SpecialModelType.GEMINI in self.node_data.model.name:
+            remove_additional_properties(processed_schema)
+            return processed_schema
+        elif SpecialModelType.OLLAMA in self.node_data.model.provider:
+            return processed_schema
+        else:
+            # Default format with name field
+            return {"schema": processed_schema, "name": "llm_response"}
+
+    def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
+        """
+        Fetch model schema
+        """
+        model_name = self.node_data.model.name
+        model_manager = ModelManager()
+        model_instance = model_manager.get_model_instance(
+            tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name
+        )
+        model_type_instance = model_instance.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+        model_credentials = model_instance.credentials
+        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
+        return model_schema
+
+    def _fetch_structured_output_schema(self) -> dict[str, Any]:
+        """
+        Fetch the structured output schema from the node data.
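+
+        Example (illustrative): with node_data.structured_output set to
+        {"schema": {"type": "object", "properties": {"name": {"type": "string"}}}},
+        the inner {"type": "object", ...} dict is validated and returned.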
+
+        Returns:
+            dict[str, Any]: The structured output schema
+        """
+        if not self.node_data.structured_output or not self.node_data.structured_output.get("schema"):
+            raise LLMNodeError("Please provide a valid structured output schema")
+        structured_output_schema = json.dumps(self.node_data.structured_output["schema"], ensure_ascii=False)
+
+        try:
+            schema = json.loads(structured_output_schema)
+            if not isinstance(schema, dict):
+                raise LLMNodeError("structured_output_schema must be a JSON object")
+            return schema
+        except json.JSONDecodeError:
+            raise LLMNodeError("structured_output_schema is not valid JSON format")
+
+    def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus:
+        """
+        Check if the current model supports structured output.
+
+        Returns:
+            SupportStructuredOutputStatus: The support status of structured output
+        """
+        # Early return if structured output is disabled
+        if (
+            not isinstance(self.node_data, LLMNodeData)
+            or not self.node_data.structured_output_enabled
+            or not self.node_data.structured_output
+        ):
+            return SupportStructuredOutputStatus.DISABLED
+        # Get model schema and check if it exists
+        model_schema = self._fetch_model_schema(self.node_data.model.provider)
+        if not model_schema:
+            return SupportStructuredOutputStatus.DISABLED
+
+        # Check if model supports structured output feature
+        return (
+            SupportStructuredOutputStatus.SUPPORTED
+            if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features)
+            else SupportStructuredOutputStatus.UNSUPPORTED
+        )
+
 
 def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
     match role:
@@ -1064,3 +1274,49 @@
     )
     prompt_messages.append(prompt_message)
     return prompt_messages
+
+
+def remove_additional_properties(schema: dict) -> None:
+    """
+    Remove additionalProperties fields from JSON schema.
+    Used for models like Gemini that don't support this property.
+
+    :param schema: JSON schema to modify in-place
+    """
+    if not isinstance(schema, dict):
+        return
+
+    # Remove additionalProperties at current level
+    schema.pop("additionalProperties", None)
+
+    # Process nested structures recursively
+    for value in schema.values():
+        if isinstance(value, dict):
+            remove_additional_properties(value)
+        elif isinstance(value, list):
+            for item in value:
+                if isinstance(item, dict):
+                    remove_additional_properties(item)
+
+
+def convert_boolean_to_string(schema: dict) -> None:
+    """
+    Convert boolean type specifications to string in JSON schema.
+
+    :param schema: JSON schema to modify in-place
+    """
+    if not isinstance(schema, dict):
+        return
+
+    # Check for boolean type at current level
+    if schema.get("type") == "boolean":
+        schema["type"] = "string"
+
+    # Process nested dictionaries and lists recursively
+    for value in schema.values():
+        if isinstance(value, dict):
+            convert_boolean_to_string(value)
+        elif isinstance(value, list):
+            for item in value:
+                if isinstance(item, dict):
+                    convert_boolean_to_string(item)
diff --git a/api/core/workflow/utils/structured_output/entities.py b/api/core/workflow/utils/structured_output/entities.py
new file mode 100644
index 0000000000..7954acbaee
--- /dev/null
+++ b/api/core/workflow/utils/structured_output/entities.py
@@ -0,0 +1,24 @@
+from enum import StrEnum
+
+
+class ResponseFormat(StrEnum):
+    """Constants for model response formats"""
+
+    JSON_SCHEMA = "json_schema"  # the model's structured output mode; models such as Gemini and GPT-4o support it
+    JSON = "JSON"  # the model's JSON mode; models such as Claude support it
+    JSON_OBJECT = "json_object"  # another alias for JSON mode; models such as deepseek-chat and qwen use it
+
+
+class SpecialModelType(StrEnum):
+    """Constants for identifying model types"""
+
+    GEMINI = "gemini"
+    OLLAMA = "ollama"
+
+
+class SupportStructuredOutputStatus(StrEnum):
+    """Constants for structured output support status"""
+
+    SUPPORTED = "supported"
+    UNSUPPORTED = "unsupported"
+    DISABLED = "disabled"
diff --git a/api/core/workflow/utils/structured_output/prompt.py b/api/core/workflow/utils/structured_output/prompt.py
new file mode 100644
index 0000000000..06d9b2056e
--- /dev/null
+++ b/api/core/workflow/utils/structured_output/prompt.py
@@ -0,0 +1,17 @@
+STRUCTURED_OUTPUT_PROMPT = """You are a helpful AI assistant. You answer questions and output in JSON format.
+constraints:
+    - You must output in JSON format.
+    - Do not output boolean value, use string type instead.
+    - Do not output integer or float value, use number type instead.
+example:
+    Here is the JSON schema:
+    {"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"}
+
+    Here is the user's question:
+    My name is John Doe and I am 30 years old.
+
+    output:
+    {"name": "John Doe", "age": 30}
+Here is the JSON schema:
+{{schema}}
+"""  # noqa: E501
diff --git a/api/models/workflow.py b/api/models/workflow.py
index 045fa0aaa0..51f2f4cc9f 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -630,6 +630,7 @@ class WorkflowNodeExecution(Base):
     @property
     def created_by_account(self):
         created_by_role = CreatedByRole(self.created_by_role)
+        # TODO(-LAN-): Avoid using db.session.get() here.
         return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
 
     @property
@@ -637,6 +638,10 @@ class WorkflowNodeExecution(Base):
     def created_by_end_user(self):
         from models.model import EndUser
 
         created_by_role = CreatedByRole(self.created_by_role)
+        # TODO(-LAN-): Avoid using db.session.get() here.
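+        # A session-scoped lookup would avoid the global session, e.g. an
+        # illustrative sketch (not wired in here):
+        #     with Session(db.engine) as session: session.get(EndUser, self.created_by)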
         return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
 
     @property
diff --git a/api/pyproject.toml b/api/pyproject.toml
index 85679a6359..4992178423 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -30,6 +30,7 @@ dependencies = [
     "gunicorn~=23.0.0",
     "httpx[socks]~=0.27.0",
     "jieba==0.42.1",
+    "json-repair>=0.41.1",
     "langfuse~=2.51.3",
     "langsmith~=0.1.77",
     "mailchimp-transactional~=1.0.50",
@@ -163,10 +164,7 @@ storage = [
 ############################################################
 # [ Tools ] dependency group
 ############################################################
-tools = [
-    "cloudscraper~=1.2.71",
-    "nltk~=3.9.1",
-]
+tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"]
 
 ############################################################
 # [ VDB ] dependency group
@@ -180,7 +178,7 @@ vdb = [
     "couchbase~=4.3.0",
     "elasticsearch==8.14.0",
     "opensearch-py==2.4.0",
-    "oracledb~=2.2.1",
+    "oracledb==3.0.0",
     "pgvecto-rs[sqlalchemy]~=0.2.1",
     "pgvector==0.2.5",
     "pymilvus~=2.5.0",
diff --git a/api/repositories/workflow_node_execution/sqlalchemy_repository.py b/api/repositories/workflow_node_execution/sqlalchemy_repository.py
index c9c6e70ff3..0594d816a2 100644
--- a/api/repositories/workflow_node_execution/sqlalchemy_repository.py
+++ b/api/repositories/workflow_node_execution/sqlalchemy_repository.py
@@ -6,7 +6,7 @@ import logging
 from collections.abc import Sequence
 from typing import Optional
 
-from sqlalchemy import UnaryExpression, asc, desc, select
+from sqlalchemy import UnaryExpression, asc, delete, desc, select
 from sqlalchemy.engine import Engine
 from sqlalchemy.orm import sessionmaker
 
@@ -168,3 +168,34 @@ class SQLAlchemyWorkflowNodeExecutionRepository:
 
         session.merge(execution)
         session.commit()
+
+    def clear(self) -> None:
+        """
+        Clear all WorkflowNodeExecution records for the current tenant_id and app_id.
+
+        This method deletes all WorkflowNodeExecution records that match the tenant_id
+        and app_id (if provided) associated with this repository instance.
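+
+        Example (illustrative; the session_factory and engine names are assumptions):
+
+            repository = SQLAlchemyWorkflowNodeExecutionRepository(
+                session_factory=sessionmaker(bind=engine),
+                tenant_id=tenant_id,
+                app_id=app_id,
+            )
+            repository.clear()  # removes every record for this tenant/app pair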
+ """ + with self._session_factory() as session: + stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id) + + if self._app_id: + stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) + + result = session.execute(stmt) + session.commit() + + deleted_count = result.rowcount + logger.info( + f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}" + + (f" and app {self._app_id}" if self._app_id else "") + ) diff --git a/api/services/account_service.py b/api/services/account_service.py index ada8109067..f930ef910b 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -407,10 +407,8 @@ class AccountService: raise PasswordResetRateLimitExceededError() - code = "".join([str(random.randint(0, 9)) for _ in range(6)]) - token = TokenManager.generate_token( - account=account, email=email, token_type="reset_password", additional_data={"code": code} - ) + code, token = cls.generate_reset_password_token(account_email, account) + send_reset_password_mail_task.delay( language=language, to=account_email, @@ -419,6 +417,22 @@ class AccountService: cls.reset_password_rate_limiter.increment_rate_limit(account_email) return token + @classmethod + def generate_reset_password_token( + cls, + email: str, + account: Optional[Account] = None, + code: Optional[str] = None, + additional_data: dict[str, Any] = {}, + ): + if not code: + code = "".join([str(random.randint(0, 9)) for _ in range(6)]) + additional_data["code"] = code + token = TokenManager.generate_token( + account=account, email=email, token_type="reset_password", additional_data=additional_data + ) + return code, token + @classmethod def revoke_reset_password_token(cls, token: str): TokenManager.revoke_token(token, "reset_password") diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 0ddd18ea27..ff3b33eecd 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -2,13 +2,14 @@ import threading from typing import Optional import contexts +from core.repository import RepositoryFactory +from core.repository.workflow_node_execution_repository import OrderConfig from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.model import App from models.workflow import ( WorkflowNodeExecution, - WorkflowNodeExecutionTriggeredFrom, WorkflowRun, ) @@ -127,17 +128,17 @@ class WorkflowRunService: if not workflow_run: return [] - node_executions = ( - db.session.query(WorkflowNodeExecution) - .filter( - WorkflowNodeExecution.tenant_id == app_model.tenant_id, - WorkflowNodeExecution.app_id == app_model.id, - WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - WorkflowNodeExecution.workflow_run_id == run_id, - ) - .order_by(WorkflowNodeExecution.index.desc()) - .all() + # Use the repository to get the node executions + repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": app_model.tenant_id, + "app_id": app_model.id, + "session_factory": db.session.get_bind, + } ) - return node_executions + # Use the repository to get the node executions with ordering + order_config = OrderConfig(order_by=["index"], order_direction="desc") + node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, 
+        node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
+
+        return list(node_executions)
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index 992942fc70..b88c7b296d 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
 from core.model_runtime.utils.encoders import jsonable_encoder
+from core.repository import RepositoryFactory
 from core.variables import Variable
 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.errors import WorkflowNodeRunFailedError
@@ -282,8 +283,15 @@ class WorkflowService:
         workflow_node_execution.created_by = account.id
         workflow_node_execution.workflow_id = draft_workflow.id
 
-        db.session.add(workflow_node_execution)
-        db.session.commit()
+        # Use the repository to save the workflow node execution
+        repository = RepositoryFactory.create_workflow_node_execution_repository(
+            params={
+                "tenant_id": app_model.tenant_id,
+                "app_id": app_model.id,
+                "session_factory": db.session.get_bind,
+            }
+        )
+        repository.save(workflow_node_execution)
 
         return workflow_node_execution
 
diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py
index c3910e2be3..4542b1b923 100644
--- a/api/tasks/remove_app_and_related_data_task.py
+++ b/api/tasks/remove_app_and_related_data_task.py
@@ -7,6 +7,7 @@ from celery import shared_task  # type: ignore
 from sqlalchemy import delete
 from sqlalchemy.exc import SQLAlchemyError
 
+from core.repository import RepositoryFactory
 from extensions.ext_database import db
 from models.dataset import AppDatasetJoin
 from models.model import (
@@ -30,7 +31,7 @@ from models.model import (
 )
 from models.tools import WorkflowToolProvider
 from models.web import PinnedConversation, SavedMessage
-from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun
+from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowRun
 
 
 @shared_task(queue="app_deletion", bind=True, max_retries=3)
@@ -187,18 +188,20 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):
 
 def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
-    def del_workflow_node_execution(workflow_node_execution_id: str):
-        db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).delete(
-            synchronize_session=False
-        )
-
-    _delete_records(
-        """select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
-        {"tenant_id": tenant_id, "app_id": app_id},
-        del_workflow_node_execution,
-        "workflow node execution",
+    # Create a repository instance for WorkflowNodeExecution
+    repository = RepositoryFactory.create_workflow_node_execution_repository(
+        params={
+            "tenant_id": tenant_id,
+            "app_id": app_id,
+            "session_factory": db.session.get_bind,
+        }
     )
 
+    # Use the clear method to delete all records for this tenant_id and app_id
+    repository.clear()
+
+    logging.info(click.style(f"Deleted workflow node executions for tenant {tenant_id} and app {app_id}", fg="green"))
+
 
 def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
     def del_workflow_app_log(workflow_app_log_id: str):
diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py
b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py index f31adab2a8..36847f8a13 100644 --- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py @@ -152,3 +152,27 @@ def test_update(repository, session): # Assert session.merge was called session_obj.merge.assert_called_once_with(execution) + + +def test_clear(repository, session, mocker: MockerFixture): + """Test clear method.""" + session_obj, _ = session + # Set up mock + mock_delete = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.delete") + mock_stmt = mocker.MagicMock() + mock_delete.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + + # Mock the execute result with rowcount + mock_result = mocker.MagicMock() + mock_result.rowcount = 5 # Simulate 5 records deleted + session_obj.execute.return_value = mock_result + + # Call method + repository.clear() + + # Assert delete was called with correct parameters + mock_delete.assert_called_once_with(WorkflowNodeExecution) + mock_stmt.where.assert_called() + session_obj.execute.assert_called_once_with(mock_stmt) + session_obj.commit.assert_called_once() diff --git a/api/uv.lock b/api/uv.lock index 4ff9c34446..6c8699dd7c 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 1 requires-python = ">=3.11, <3.13" resolution-markers = [ "python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy'", @@ -1178,6 +1177,7 @@ dependencies = [ { name = "gunicorn" }, { name = "httpx", extra = ["socks"] }, { name = "jieba" }, + { name = "json-repair" }, { name = "langfuse" }, { name = "langsmith" }, { name = "mailchimp-transactional" }, @@ -1346,6 +1346,7 @@ requires-dist = [ { name = "gunicorn", specifier = "~=23.0.0" }, { name = "httpx", extras = ["socks"], specifier = "~=0.27.0" }, { name = "jieba", specifier = "==0.42.1" }, + { name = "json-repair", specifier = ">=0.41.1" }, { name = "langfuse", specifier = "~=2.51.3" }, { name = "langsmith", specifier = "~=0.1.77" }, { name = "mailchimp-transactional", specifier = "~=1.0.50" }, @@ -1470,7 +1471,7 @@ vdb = [ { name = "couchbase", specifier = "~=4.3.0" }, { name = "elasticsearch", specifier = "==8.14.0" }, { name = "opensearch-py", specifier = "==2.4.0" }, - { name = "oracledb", specifier = "~=2.2.1" }, + { name = "oracledb", specifier = "==3.0.0" }, { name = "pgvecto-rs", extras = ["sqlalchemy"], specifier = "~=0.2.1" }, { name = "pgvector", specifier = "==0.2.5" }, { name = "pymilvus", specifier = "~=2.5.0" }, @@ -2524,6 +2525,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/91/29/df4b9b42f2be0b623cbd5e2140cafcaa2bef0759a00b7b70104dcfe2fb51/joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6", size = 301817 }, ] +[[package]] +name = "json-repair" +version = "0.41.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6d/6a/6c7a75a10da6dc807b582f2449034da1ed74415e8899746bdfff97109012/json_repair-0.41.1.tar.gz", hash = "sha256:bba404b0888c84a6b86ecc02ec43b71b673cfee463baf6da94e079c55b136565", size = 31208 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/10/5c/abd7495c934d9af5c263c2245ae30cfaa716c3c0cf027b2b8fa686ee7bd4/json_repair-0.41.1-py3-none-any.whl", hash = "sha256:0e181fd43a696887881fe19fed23422a54b3e4c558b6ff27a86a8c3ddde9ae79", size = 21578 }, 
+]
+
 [[package]]
 name = "jsonpath-python"
 version = "1.0.6"
@@ -3590,23 +3600,23 @@ wheels = [
 
 [[package]]
 name = "oracledb"
-version = "2.2.1"
+version = "3.0.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "cryptography" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/36/fb/3fbacb351833dd794abb184303a5761c4bb33df9d770fd15d01ead2ff738/oracledb-2.2.1.tar.gz", hash = "sha256:8464c6f0295f3318daf6c2c72c83c2dcbc37e13f8fd44e3e39ff8665f442d6b6", size = 580818 }
+sdist = { url = "https://files.pythonhosted.org/packages/bf/39/712f797b75705c21148fa1d98651f63c2e5cc6876e509a0a9e2f5b406572/oracledb-3.0.0.tar.gz", hash = "sha256:64dc86ee5c032febc556798b06e7b000ef6828bb0252084f6addacad3363db85", size = 840431 }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/74/b7/a4238295944670fb8cc50a8cc082e0af5a0440bfb1c2bac2b18429c0a579/oracledb-2.2.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:fb6d9a4d7400398b22edb9431334f9add884dec9877fd9c4ae531e1ccc6ee1fd", size = 3551303 },
-    { url = "https://files.pythonhosted.org/packages/4f/5f/98481d44976cd2b3086361f2d50026066b24090b0e6cd1f2a12c824e9717/oracledb-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07757c240afbb4f28112a6affc2c5e4e34b8a92e5bb9af81a40fba398da2b028", size = 12258455 },
-    { url = "https://files.pythonhosted.org/packages/e9/54/06b2540286e2b63f60877d6f3c6c40747e216b6eeda0756260e194897076/oracledb-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63daec72f853c47179e98493e9b732909d96d495bdceb521c5973a3940d28142", size = 12317476 },
-    { url = "https://files.pythonhosted.org/packages/4d/1a/67814439a4e24df83281a72cb0ba433d6b74e1bff52a9975b87a725bcba5/oracledb-2.2.1-cp311-cp311-win32.whl", hash = "sha256:fec5318d1e0ada7e4674574cb6c8d1665398e8b9c02982279107212f05df1660", size = 1369368 },
-    { url = "https://files.pythonhosted.org/packages/e3/b8/b2a8f0607be17f58ec6689ad5fd15c2956f4996c64547325e96439570edf/oracledb-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:5134dccb5a11bc755abf02fd49be6dc8141dfcae4b650b55d40509323d00b5c2", size = 1655035 },
-    { url = "https://files.pythonhosted.org/packages/24/5b/2fff762243030f31a6b1561fc8eeb142e69ba6ebd3e7fbe4a2c82f0eb6f0/oracledb-2.2.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ac5716bc9a48247fdf563f5f4ec097f5c9f074a60fd130cdfe16699208ca29b5", size = 3583960 },
-    { url = "https://files.pythonhosted.org/packages/e6/88/34117ae830e7338af7c0481f1c0fc6eda44d558e12f9203b45b491e53071/oracledb-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c150bddb882b7c73fb462aa2d698744da76c363e404570ed11d05b65811d96c3", size = 11749006 },
-    { url = "https://files.pythonhosted.org/packages/9d/58/bac788f18c21f727955652fe238de2d24a12c2b455ed4db18a6d23ff781e/oracledb-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:193e1888411bc21187ade4b16b76820bd1e8f216e25602f6cd0a97d45723c1dc", size = 11950663 },
-    { url = "https://files.pythonhosted.org/packages/3b/e2/005f66ae919c6f7c73e06863256cf43aa844330e2dc61a5f9779ae44a801/oracledb-2.2.1-cp312-cp312-win32.whl", hash = "sha256:44a960f8bbb0711af222e0a9690e037b6a2a382e0559ae8eeb9cfafe26c7a3bc", size = 1324255 },
-    { url = "https://files.pythonhosted.org/packages/e6/25/759eb2143134513382e66d874c4aacfd691dec3fef7141170cfa6c1b154f/oracledb-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:470136add32f0d0084225c793f12a52b61b52c3dc00c9cd388ec6a3db3a7643e", size = 1613047 },
+    { url = "https://files.pythonhosted.org/packages/fa/bf/d872c4b3fc15cd3261fe0ea72b21d181700c92dbc050160e161654987062/oracledb-3.0.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:52daa9141c63dfa75c07d445e9bb7f69f43bfb3c5a173ecc48c798fe50288d26", size = 4312963 },
+    { url = "https://files.pythonhosted.org/packages/b1/ea/01ee29e76a610a53bb34fdc1030f04b7669c3f80b25f661e07850fc6160e/oracledb-3.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:af98941789df4c6aaaf4338f5b5f6b7f2c8c3fe6f8d6a9382f177f350868747a", size = 2661536 },
+    { url = "https://files.pythonhosted.org/packages/3d/8e/ad380e34a46819224423b4773e58c350bc6269643c8969604097ced8c3bc/oracledb-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9812bb48865aaec35d73af54cd1746679f2a8a13cbd1412ab371aba2e39b3943", size = 2867461 },
+    { url = "https://files.pythonhosted.org/packages/96/09/ecc4384a27fd6e1e4de824ae9c160e4ad3aaebdaade5b4bdcf56a4d1ff63/oracledb-3.0.0-cp311-cp311-win32.whl", hash = "sha256:6c27fe0de64f2652e949eb05b3baa94df9b981a4a45fa7f8a991e1afb450c8e2", size = 1752046 },
+    { url = "https://files.pythonhosted.org/packages/62/e8/f34bde24050c6e55eeba46b23b2291f2dd7fd272fa8b322dcbe71be55778/oracledb-3.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:f922709672002f0b40997456f03a95f03e5712a86c61159951c5ce09334325e0", size = 2101210 },
+    { url = "https://files.pythonhosted.org/packages/6f/fc/24590c3a3d41e58494bd3c3b447a62835138e5f9b243d9f8da0cfb5da8dc/oracledb-3.0.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:acd0e747227dea01bebe627b07e958bf36588a337539f24db629dc3431d3f7eb", size = 4351993 },
+    { url = "https://files.pythonhosted.org/packages/b7/b6/1f3b0b7bb94d53e8857d77b2e8dbdf6da091dd7e377523e24b79dac4fd71/oracledb-3.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f8b402f77c22af031cd0051aea2472ecd0635c1b452998f511aa08b7350c90a4", size = 2532640 },
+    { url = "https://files.pythonhosted.org/packages/72/1a/1815f6c086ab49c00921cf155ff5eede5267fb29fcec37cb246339a5ce4d/oracledb-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:378a27782e9a37918bd07a5a1427a77cb6f777d0a5a8eac9c070d786f50120ef", size = 2765949 },
+    { url = "https://files.pythonhosted.org/packages/33/8d/208900f8d372909792ee70b2daad3f7361181e55f2217c45ed9dff658b54/oracledb-3.0.0-cp312-cp312-win32.whl", hash = "sha256:54a28c2cb08316a527cd1467740a63771cc1c1164697c932aa834c0967dc4efc", size = 1709373 },
+    { url = "https://files.pythonhosted.org/packages/0c/5e/c21754f19c896102793c3afec2277e2180aa7d505e4d7fcca24b52d14e4f/oracledb-3.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:8289bad6d103ce42b140e40576cf0c81633e344d56e2d738b539341eacf65624", size = 2056452 },
 ]
 
 [[package]]
@@ -4074,6 +4084,8 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/af/cd/ed6e429fb0792ce368f66e83246264dd3a7a045b0b1e63043ed22a063ce5/pycryptodome-3.19.1-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7c9e222d0976f68d0cf6409cfea896676ddc1d98485d601e9508f90f60e2b0a2", size = 2144914 },
     { url = "https://files.pythonhosted.org/packages/f6/23/b064bd4cfbf2cc5f25afcde0e7c880df5b20798172793137ba4b62d82e72/pycryptodome-3.19.1-cp35-abi3-win32.whl", hash = "sha256:4805e053571140cb37cf153b5c72cd324bb1e3e837cbe590a19f69b6cf85fd03", size = 1713105 },
"https://files.pythonhosted.org/packages/7d/e0/ded1968a5257ab34216a0f8db7433897a2337d59e6d03be113713b346ea2/pycryptodome-3.19.1-cp35-abi3-win_amd64.whl", hash = "sha256:a470237ee71a1efd63f9becebc0ad84b88ec28e6784a2047684b693f458f41b7", size = 1749222 }, + { url = "https://files.pythonhosted.org/packages/1d/e3/0c9679cd66cf5604b1f070bdf4525a0c01a15187be287d8348b2eafb718e/pycryptodome-3.19.1-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:ed932eb6c2b1c4391e166e1a562c9d2f020bfff44a0e1b108f67af38b390ea89", size = 1629005 }, + { url = "https://files.pythonhosted.org/packages/13/75/0d63bf0daafd0580b17202d8a9dd57f28c8487f26146b3e2799b0c5a059c/pycryptodome-3.19.1-pp27-pypy_73-win32.whl", hash = "sha256:81e9d23c0316fc1b45d984a44881b220062336bbdc340aa9218e8d0656587934", size = 1697997 }, ] [[package]] diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index a8f7b755fb..c6d41849ef 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -130,6 +130,7 @@ services: HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128} HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128} SANDBOX_PORT: ${SANDBOX_PORT:-8194} + PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} volumes: - ./volumes/sandbox/dependencies:/dependencies - ./volumes/sandbox/conf:/conf diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 27d6d660d0..1702a5395f 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -60,6 +60,7 @@ services: HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128} HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128} SANDBOX_PORT: ${SANDBOX_PORT:-8194} + PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} volumes: - ./volumes/sandbox/dependencies:/dependencies - ./volumes/sandbox/conf:/conf diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index e01b9f7e9a..def4b77c65 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -603,6 +603,7 @@ services: HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128} HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128} SANDBOX_PORT: ${SANDBOX_PORT:-8194} + PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} volumes: - ./volumes/sandbox/dependencies:/dependencies - ./volumes/sandbox/conf:/conf diff --git a/web/README.md b/web/README.md index 3236347e80..3d9fd2de87 100644 --- a/web/README.md +++ b/web/README.md @@ -7,7 +7,7 @@ This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next ### Run by source code Before starting the web frontend service, please make sure the following environment is ready. -- [Node.js](https://nodejs.org) >= v18.x +- [Node.js](https://nodejs.org) >= v22.11.x - [pnpm](https://pnpm.io) v10.x First, install the dependencies: diff --git a/web/app/components/app/configuration/config-var/config-select/index.spec.tsx b/web/app/components/app/configuration/config-var/config-select/index.spec.tsx new file mode 100644 index 0000000000..18df318de3 --- /dev/null +++ b/web/app/components/app/configuration/config-var/config-select/index.spec.tsx @@ -0,0 +1,82 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import ConfigSelect from './index' + +jest.mock('react-sortablejs', () => ({ + ReactSortable: ({ children }: { children: React.ReactNode }) =>
{children}
, +})) + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +describe('ConfigSelect Component', () => { + const defaultProps = { + options: ['Option 1', 'Option 2'], + onChange: jest.fn(), + } + + afterEach(() => { + jest.clearAllMocks() + }) + + it('renders all options', () => { + render() + + defaultProps.options.forEach((option) => { + expect(screen.getByDisplayValue(option)).toBeInTheDocument() + }) + }) + + it('renders add button', () => { + render() + + expect(screen.getByText('appDebug.variableConfig.addOption')).toBeInTheDocument() + }) + + it('handles option deletion', () => { + render() + const optionContainer = screen.getByDisplayValue('Option 1').closest('div') + const deleteButton = optionContainer?.querySelector('div[role="button"]') + + if (!deleteButton) return + fireEvent.click(deleteButton) + expect(defaultProps.onChange).toHaveBeenCalledWith(['Option 2']) + }) + + it('handles adding new option', () => { + render() + const addButton = screen.getByText('appDebug.variableConfig.addOption') + + fireEvent.click(addButton) + + expect(defaultProps.onChange).toHaveBeenCalledWith([...defaultProps.options, '']) + }) + + it('applies focus styles on input focus', () => { + render() + const firstInput = screen.getByDisplayValue('Option 1') + + fireEvent.focus(firstInput) + + expect(firstInput.closest('div')).toHaveClass('border-components-input-border-active') + }) + + it('applies delete hover styles', () => { + render() + const optionContainer = screen.getByDisplayValue('Option 1').closest('div') + const deleteButton = optionContainer?.querySelector('div[role="button"]') + + if (!deleteButton) return + fireEvent.mouseEnter(deleteButton) + expect(optionContainer).toHaveClass('border-components-input-border-destructive') + }) + + it('renders empty state correctly', () => { + render() + + expect(screen.queryByRole('textbox')).not.toBeInTheDocument() + expect(screen.getByText('appDebug.variableConfig.addOption')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/configuration/config-var/config-select/index.tsx b/web/app/components/app/configuration/config-var/config-select/index.tsx index d2dc1662c1..40ddaef78f 100644 --- a/web/app/components/app/configuration/config-var/config-select/index.tsx +++ b/web/app/components/app/configuration/config-var/config-select/index.tsx @@ -51,7 +51,7 @@ const ConfigSelect: FC = ({ { const value = e.target.value @@ -67,6 +67,7 @@ const ConfigSelect: FC = ({ onBlur={() => setFocusID(null)} />
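
A note on the hunk above: the surrounding JSX did not survive extraction (the `...` lines mark what was lost), but the new spec queries the delete control with `querySelector('div[role="button"]')`, which is exactly what the added `role='button'` attribute enables. A minimal sketch of the assumed pattern; `DeleteOption` and `onDelete` are illustrative names, not the verbatim ConfigSelect markup:

```tsx
// Illustrative only - a clickable <div> exposed to tests and
// assistive technology via an explicit button role.
const DeleteOption = ({ onDelete }: { onDelete: () => void }) => (
  <div
    role='button' // makes `div[role="button"]` queries (and screen readers) work
    onClick={onDelete}
  >
    Remove
  </div>
)
```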
diff --git a/web/app/components/app/overview/settings/index.tsx b/web/app/components/app/overview/settings/index.tsx
index 896229d433..679d616e54 100644
--- a/web/app/components/app/overview/settings/index.tsx
+++ b/web/app/components/app/overview/settings/index.tsx
@@ -162,11 +162,22 @@ const SettingsModal: FC = ({
     return check
   }
 
+  const validatePrivacyPolicy = (privacyPolicy: string | null) => {
+    if (privacyPolicy === null || privacyPolicy?.length === 0)
+      return true
+
+    return privacyPolicy.startsWith('http://') || privacyPolicy.startsWith('https://')
+  }
+
   if (inputInfo !== null) {
     if (!validateColorHex(inputInfo.chatColorTheme)) {
       notify({ type: 'error', message: t(`${prefixSettings}.invalidHexMessage`) })
       return
     }
+    if (!validatePrivacyPolicy(inputInfo.privacyPolicy)) {
+      notify({ type: 'error', message: t(`${prefixSettings}.invalidPrivacyPolicy`) })
+      return
+    }
   }
 
   setSaveLoading(true)
@@ -410,7 +421,7 @@ const SettingsModal: FC = ({
-              components={{ privacyPolicyLink: <Link ... /> }}
+              components={{ privacyPolicyLink: <Link ... /> }}
               />
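
For clarity, the new `validatePrivacyPolicy` guard treats an empty value as valid (the link is optional) and otherwise accepts only absolute http/https URLs. A standalone sketch of the same check; the sample inputs are illustrative:

```tsx
// Mirrors the guard added in the hunk above.
const validatePrivacyPolicy = (privacyPolicy: string | null) => {
  if (privacyPolicy === null || privacyPolicy.length === 0)
    return true // nothing set: nothing to validate

  return privacyPolicy.startsWith('http://') || privacyPolicy.startsWith('https://')
}

validatePrivacyPolicy(null)                  // true
validatePrivacyPolicy('')                    // true
validatePrivacyPolicy('https://example.com') // true
validatePrivacyPolicy('example.com')         // false -> triggers the invalidPrivacyPolicy toast
```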
diff --git a/web/app/components/base/chat/chat/answer/index.tsx b/web/app/components/base/chat/chat/answer/index.tsx
--- a/web/app/components/base/chat/chat/answer/index.tsx
+++ b/web/app/components/base/chat/chat/answer/index.tsx
@@ ... @@ const Answer: FC<AnswerProps> = ({
   )
 }
 
-export default memo(Answer)
+export default memo(Answer, (prevProps, nextProps) =>
+  prevProps.responding === false && nextProps.responding === false,
+)
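
The second argument to `memo` above is an equality check, not a "should update" flag: returning `true` tells React the props are equal and the re-render can be skipped. With this comparator, `Answer` re-renders whenever the previous or next `responding` is `true`, and skips re-renders once streaming has settled, which also means changes to any other prop are ignored in that settled state. A standalone sketch of the semantics; `AnswerLikeProps` and `Message` are illustrative, not the real component:

```tsx
import { memo } from 'react'

type AnswerLikeProps = { responding: boolean, content: string }

// React.memo comparator: returning true means "props equal, skip re-render".
const areEqual = (prev: AnswerLikeProps, next: AnswerLikeProps) =>
  prev.responding === false && next.responding === false

const Message = ({ content }: AnswerLikeProps) => <p>{content}</p>

// Re-renders while a response streams (responding flips or stays true),
// but not after it has finished - even if `content` later changes.
export default memo(Message, areEqual)
```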
diff --git a/web/app/components/base/checkbox/assets/indeterminate-icon.tsx b/web/app/components/base/checkbox/assets/indeterminate-icon.tsx
index 7c8c609e08..56df8db6a4 100644
--- a/web/app/components/base/checkbox/assets/indeterminate-icon.tsx
+++ b/web/app/components/base/checkbox/assets/indeterminate-icon.tsx
@@ -1,8 +1,10 @@
 const IndeterminateIcon = () => {
   return (
-    <svg ... >
-      <path ... />
-    </svg>
+    <div data-testid="indeterminate-icon">
+      <svg ... >
+        <path ... />
+      </svg>
+    </div>
   )
 }
 
diff --git a/web/app/components/base/checkbox/index.spec.tsx b/web/app/components/base/checkbox/index.spec.tsx
new file mode 100644
index 0000000000..7ef901aef5
--- /dev/null
+++ b/web/app/components/base/checkbox/index.spec.tsx
@@ -0,0 +1,67 @@
+import { fireEvent, render, screen } from '@testing-library/react'
+import Checkbox from './index'
+
+describe('Checkbox Component', () => {
+  const mockProps = {
+    id: 'test',
+  }
+
+  it('renders unchecked checkbox by default', () => {
+    render(<Checkbox {...mockProps} />)
+    const checkbox = screen.getByTestId('checkbox-test')
+    expect(checkbox).toBeInTheDocument()
+    expect(checkbox).not.toHaveClass('bg-components-checkbox-bg')
+  })
+
+  it('renders checked checkbox when checked prop is true', () => {
+    render(<Checkbox {...mockProps} checked />)
+    const checkbox = screen.getByTestId('checkbox-test')
+    expect(checkbox).toHaveClass('bg-components-checkbox-bg')
+    expect(screen.getByTestId('check-icon-test')).toBeInTheDocument()
+  })
+
+  it('renders indeterminate state correctly', () => {
+    render(<Checkbox {...mockProps} indeterminate />)
+    expect(screen.getByTestId('indeterminate-icon')).toBeInTheDocument()
+  })
+
+  it('handles click events when not disabled', () => {
+    const onCheck = jest.fn()
+    render(<Checkbox {...mockProps} onCheck={onCheck} />)
+    const checkbox = screen.getByTestId('checkbox-test')
+
+    fireEvent.click(checkbox)
+    expect(onCheck).toHaveBeenCalledTimes(1)
+  })
+
+  it('does not handle click events when disabled', () => {
+    const onCheck = jest.fn()
+    render(<Checkbox {...mockProps} onCheck={onCheck} disabled />)
+    const checkbox = screen.getByTestId('checkbox-test')
+
+    fireEvent.click(checkbox)
+    expect(onCheck).not.toHaveBeenCalled()
+    expect(checkbox).toHaveClass('cursor-not-allowed')
+  })
+
+  it('applies custom className when provided', () => {
+    const customClass = 'custom-class'
+    render(<Checkbox {...mockProps} className={customClass} />)
+    const checkbox = screen.getByTestId('checkbox-test')
+    expect(checkbox).toHaveClass(customClass)
+  })
+
+  it('applies correct styles for disabled checked state', () => {
+    render(<Checkbox {...mockProps} checked disabled />)
+    const checkbox = screen.getByTestId('checkbox-test')
+    expect(checkbox).toHaveClass('bg-components-checkbox-bg-disabled-checked')
+    expect(checkbox).toHaveClass('cursor-not-allowed')
+  })
+
+  it('applies correct styles for disabled unchecked state', () => {
+    render(<Checkbox {...mockProps} disabled />)
+    const checkbox = screen.getByTestId('checkbox-test')
+    expect(checkbox).toHaveClass('bg-components-checkbox-bg-disabled')
+    expect(checkbox).toHaveClass('cursor-not-allowed')
+  })
+})
diff --git a/web/app/components/base/checkbox/index.tsx b/web/app/components/base/checkbox/index.tsx
index 99a31234f7..3e47967c62 100644
--- a/web/app/components/base/checkbox/index.tsx
+++ b/web/app/components/base/checkbox/index.tsx
@@ -40,9 +40,10 @@ const Checkbox = ({
           return
         onCheck?.()
       }}
+      data-testid={`checkbox-${id}`}
     >
       {!checked && indeterminate && <IndeterminateIcon />}
-      {checked && <... />}
+      {checked && <... data-testid={`check-icon-${id}`} />}
     </div>
   )
 }
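
The template-string test ids added here let tests address a specific checkbox instance directly instead of matching on styling classes. A small usage sketch; the `id` value `terms` is illustrative, and the check-icon component name above is elided because the original JSX was lost in extraction:

```tsx
import { render, screen } from '@testing-library/react'
import Checkbox from './index'

render(<Checkbox id='terms' checked />)

// Ids are interpolated into the data-testid attributes added in this diff:
screen.getByTestId('checkbox-terms')    // the clickable wrapper element
screen.getByTestId('check-icon-terms')  // the check icon, rendered only while checked
```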
diff --git a/web/app/components/base/form/components/label.spec.tsx b/web/app/components/base/form/components/label.spec.tsx
new file mode 100644
index 0000000000..b2dc98a21e
--- /dev/null
+++ b/web/app/components/base/form/components/label.spec.tsx
@@ -0,0 +1,53 @@
+import { fireEvent, render, screen } from '@testing-library/react'
+import Label from './label'
+
+jest.mock('react-i18next', () => ({
+  useTranslation: () => ({
+    t: (key: string) => key,
+  }),
+}))
+
+describe('Label Component', () => {
+  const defaultProps = {
+    htmlFor: 'test-input',
+    label: 'Test Label',
+  }
+
+  it('renders basic label correctly', () => {
+    render(