From 3e6a6bf396f0a493715ded04c08ee6cc2b705f64 Mon Sep 17 00:00:00 2001 From: "teruo OSHIDA(JP_SMN)" <37995617+monstaruos@users.noreply.github.com> Date: Fri, 23 Aug 2024 09:21:31 +0900 Subject: [PATCH 01/24] fix: wrong usage of created_at on the modal for API Key (#7548) --- web/app/components/develop/secret-key/secret-key-modal.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/app/components/develop/secret-key/secret-key-modal.tsx b/web/app/components/develop/secret-key/secret-key-modal.tsx index 01973ff028..fc780be05d 100644 --- a/web/app/components/develop/secret-key/secret-key-modal.tsx +++ b/web/app/components/develop/secret-key/secret-key-modal.tsx @@ -115,7 +115,7 @@ const SecretKeyModal = ({
{generateToken(api.token)}
{formatTime(Number(api.created_at), t('appLog.dateTimeFormat') as string)}
-
{api.last_used_at ? formatTime(Number(api.created_at), t('appLog.dateTimeFormat') as string) : t('appApi.never')}
+
{api.last_used_at ? formatTime(Number(api.last_used_at), t('appLog.dateTimeFormat') as string) : t('appApi.never')}
Date: Fri, 23 Aug 2024 08:33:41 +0800 Subject: [PATCH 02/24] fix: correct response structure in openapi documentation of app (#7556) --- .../components/develop/template/template_advanced_chat.en.mdx | 2 +- .../components/develop/template/template_advanced_chat.zh.mdx | 2 +- web/app/components/develop/template/template_chat.en.mdx | 4 ++-- web/app/components/develop/template/template_chat.zh.mdx | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/web/app/components/develop/template/template_advanced_chat.en.mdx b/web/app/components/develop/template/template_advanced_chat.en.mdx index cec42a4e32..6487ac79d7 100644 --- a/web/app/components/develop/template/template_advanced_chat.en.mdx +++ b/web/app/components/develop/template/template_advanced_chat.en.mdx @@ -285,7 +285,7 @@ Chat applications support session persistence, allowing previous chat history to data: {"event": "message", "message_id": "5ad4cb98-f0c7-4085-b384-88c403be6290", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "answer": " to", "created_at": 1679586595} data: {"event": "message", "message_id": : "5ad4cb98-f0c7-4085-b384-88c403be6290", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "answer": " meet", "created_at": 1679586595} data: {"event": "message", "message_id": : "5ad4cb98-f0c7-4085-b384-88c403be6290", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "answer": " you", "created_at": 1679586595} - data: {"event": "message_end", "id": "5e52ce04-874b-4d27-9045-b3bc80def685", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "metadata": {"usage": {"prompt_tokens": 1033, "prompt_unit_price": "0.001", "prompt_price_unit": "0.001", "prompt_price": "0.0010330", "completion_tokens": 135, "completion_unit_price": "0.002", "completion_price_unit": "0.001", "completion_price": "0.0002700", "total_tokens": 1168, "total_price": "0.0013030", "currency": "USD", "latency": 1.381760165997548, "retriever_resources": [{"position": 1, "dataset_id": "101b4c97-fc2e-463c-90b1-5261a4cdcafb", "dataset_name": "iPhone", "document_id": "8dd1ad74-0b5f-4175-b735-7d98bbbb4e00", "document_name": "iPhone List", "segment_id": "ed599c7f-2766-4294-9d1d-e5235a61270a", "score": 0.98457545, "content": "\"Model\",\"Release Date\",\"Display Size\",\"Resolution\",\"Processor\",\"RAM\",\"Storage\",\"Camera\",\"Battery\",\"Operating System\"\n\"iPhone 13 Pro Max\",\"September 24, 2021\",\"6.7 inch\",\"1284 x 2778\",\"Hexa-core (2x3.23 GHz Avalanche + 4x1.82 GHz Blizzard)\",\"6 GB\",\"128, 256, 512 GB, 1TB\",\"12 MP\",\"4352 mAh\",\"iOS 15\""}]}}} + data: {"event": "message_end", "id": "5e52ce04-874b-4d27-9045-b3bc80def685", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "metadata": {"usage": {"prompt_tokens": 1033, "prompt_unit_price": "0.001", "prompt_price_unit": "0.001", "prompt_price": "0.0010330", "completion_tokens": 135, "completion_unit_price": "0.002", "completion_price_unit": "0.001", "completion_price": "0.0002700", "total_tokens": 1168, "total_price": "0.0013030", "currency": "USD", "latency": 1.381760165997548}, "retriever_resources": [{"position": 1, "dataset_id": "101b4c97-fc2e-463c-90b1-5261a4cdcafb", "dataset_name": "iPhone", "document_id": "8dd1ad74-0b5f-4175-b735-7d98bbbb4e00", "document_name": "iPhone List", "segment_id": "ed599c7f-2766-4294-9d1d-e5235a61270a", "score": 0.98457545, "content": "\"Model\",\"Release Date\",\"Display Size\",\"Resolution\",\"Processor\",\"RAM\",\"Storage\",\"Camera\",\"Battery\",\"Operating System\"\n\"iPhone 13 Pro Max\",\"September 24, 2021\",\"6.7 inch\",\"1284 x 2778\",\"Hexa-core (2x3.23 GHz Avalanche + 4x1.82 GHz Blizzard)\",\"6 GB\",\"128, 256, 512 GB, 1TB\",\"12 MP\",\"4352 mAh\",\"iOS 15\""}]}} data: {"event": "tts_message", "conversation_id": "23dd85f3-1a41-4ea0-b7a9-062734ccfaf9", "message_id": "a8bdc41c-13b2-4c18-bfd9-054b9803038c", "created_at": 1721205487, "task_id": "3bf8a0bb-e73b-4690-9e66-4e429bad8ee7", "audio": "qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq"} data: {"event": "tts_message_end", "conversation_id": "23dd85f3-1a41-4ea0-b7a9-062734ccfaf9", "message_id": "a8bdc41c-13b2-4c18-bfd9-054b9803038c", "created_at": 1721205487, "task_id": "3bf8a0bb-e73b-4690-9e66-4e429bad8ee7", "audio": ""} ``` diff --git a/web/app/components/develop/template/template_advanced_chat.zh.mdx b/web/app/components/develop/template/template_advanced_chat.zh.mdx index 97a2df01a6..33551509e5 100755 --- a/web/app/components/develop/template/template_advanced_chat.zh.mdx +++ b/web/app/components/develop/template/template_advanced_chat.zh.mdx @@ -296,7 +296,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' data: {"event": "message", "message_id": "5ad4cb98-f0c7-4085-b384-88c403be6290", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "answer": " to", "created_at": 1679586595} data: {"event": "message", "message_id": : "5ad4cb98-f0c7-4085-b384-88c403be6290", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "answer": " meet", "created_at": 1679586595} data: {"event": "message", "message_id": : "5ad4cb98-f0c7-4085-b384-88c403be6290", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "answer": " you", "created_at": 1679586595} - data: {"event": "message_end", "id": "5e52ce04-874b-4d27-9045-b3bc80def685", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "metadata": {"usage": {"prompt_tokens": 1033, "prompt_unit_price": "0.001", "prompt_price_unit": "0.001", "prompt_price": "0.0010330", "completion_tokens": 135, "completion_unit_price": "0.002", "completion_price_unit": "0.001", "completion_price": "0.0002700", "total_tokens": 1168, "total_price": "0.0013030", "currency": "USD", "latency": 1.381760165997548, "retriever_resources": [{"position": 1, "dataset_id": "101b4c97-fc2e-463c-90b1-5261a4cdcafb", "dataset_name": "iPhone", "document_id": "8dd1ad74-0b5f-4175-b735-7d98bbbb4e00", "document_name": "iPhone List", "segment_id": "ed599c7f-2766-4294-9d1d-e5235a61270a", "score": 0.98457545, "content": "\"Model\",\"Release Date\",\"Display Size\",\"Resolution\",\"Processor\",\"RAM\",\"Storage\",\"Camera\",\"Battery\",\"Operating System\"\n\"iPhone 13 Pro Max\",\"September 24, 2021\",\"6.7 inch\",\"1284 x 2778\",\"Hexa-core (2x3.23 GHz Avalanche + 4x1.82 GHz Blizzard)\",\"6 GB\",\"128, 256, 512 GB, 1TB\",\"12 MP\",\"4352 mAh\",\"iOS 15\""}]}}} + data: {"event": "message_end", "id": "5e52ce04-874b-4d27-9045-b3bc80def685", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "metadata": {"usage": {"prompt_tokens": 1033, "prompt_unit_price": "0.001", "prompt_price_unit": "0.001", "prompt_price": "0.0010330", "completion_tokens": 135, "completion_unit_price": "0.002", "completion_price_unit": "0.001", "completion_price": "0.0002700", "total_tokens": 1168, "total_price": "0.0013030", "currency": "USD", "latency": 1.381760165997548}, "retriever_resources": [{"position": 1, "dataset_id": "101b4c97-fc2e-463c-90b1-5261a4cdcafb", "dataset_name": "iPhone", "document_id": "8dd1ad74-0b5f-4175-b735-7d98bbbb4e00", "document_name": "iPhone List", "segment_id": "ed599c7f-2766-4294-9d1d-e5235a61270a", "score": 0.98457545, "content": "\"Model\",\"Release Date\",\"Display Size\",\"Resolution\",\"Processor\",\"RAM\",\"Storage\",\"Camera\",\"Battery\",\"Operating System\"\n\"iPhone 13 Pro Max\",\"September 24, 2021\",\"6.7 inch\",\"1284 x 2778\",\"Hexa-core (2x3.23 GHz Avalanche + 4x1.82 GHz Blizzard)\",\"6 GB\",\"128, 256, 512 GB, 1TB\",\"12 MP\",\"4352 mAh\",\"iOS 15\""}]}} data: {"event": "tts_message", "conversation_id": "23dd85f3-1a41-4ea0-b7a9-062734ccfaf9", "message_id": "a8bdc41c-13b2-4c18-bfd9-054b9803038c", "created_at": 1721205487, "task_id": "3bf8a0bb-e73b-4690-9e66-4e429bad8ee7", "audio": "qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq"} data: {"event": "tts_message_end", "conversation_id": "23dd85f3-1a41-4ea0-b7a9-062734ccfaf9", "message_id": "a8bdc41c-13b2-4c18-bfd9-054b9803038c", "created_at": 1721205487, "task_id": "3bf8a0bb-e73b-4690-9e66-4e429bad8ee7", "audio": ""} ``` diff --git a/web/app/components/develop/template/template_chat.en.mdx b/web/app/components/develop/template/template_chat.en.mdx index 00c02936a6..07840640f4 100644 --- a/web/app/components/develop/template/template_chat.en.mdx +++ b/web/app/components/develop/template/template_chat.en.mdx @@ -242,7 +242,7 @@ Chat applications support session persistence, allowing previous chat history to data: {"event": "message", "message_id": "5ad4cb98-f0c7-4085-b384-88c403be6290", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "answer": " to", "created_at": 1679586595} data: {"event": "message", "message_id": : "5ad4cb98-f0c7-4085-b384-88c403be6290", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "answer": " meet", "created_at": 1679586595} data: {"event": "message", "message_id": : "5ad4cb98-f0c7-4085-b384-88c403be6290", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "answer": " you", "created_at": 1679586595} - data: {"event": "message_end", "id": "5e52ce04-874b-4d27-9045-b3bc80def685", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "metadata": {"usage": {"prompt_tokens": 1033, "prompt_unit_price": "0.001", "prompt_price_unit": "0.001", "prompt_price": "0.0010330", "completion_tokens": 135, "completion_unit_price": "0.002", "completion_price_unit": "0.001", "completion_price": "0.0002700", "total_tokens": 1168, "total_price": "0.0013030", "currency": "USD", "latency": 1.381760165997548, "retriever_resources": [{"position": 1, "dataset_id": "101b4c97-fc2e-463c-90b1-5261a4cdcafb", "dataset_name": "iPhone", "document_id": "8dd1ad74-0b5f-4175-b735-7d98bbbb4e00", "document_name": "iPhone List", "segment_id": "ed599c7f-2766-4294-9d1d-e5235a61270a", "score": 0.98457545, "content": "\"Model\",\"Release Date\",\"Display Size\",\"Resolution\",\"Processor\",\"RAM\",\"Storage\",\"Camera\",\"Battery\",\"Operating System\"\n\"iPhone 13 Pro Max\",\"September 24, 2021\",\"6.7 inch\",\"1284 x 2778\",\"Hexa-core (2x3.23 GHz Avalanche + 4x1.82 GHz Blizzard)\",\"6 GB\",\"128, 256, 512 GB, 1TB\",\"12 MP\",\"4352 mAh\",\"iOS 15\""}]}}} + data: {"event": "message_end", "id": "5e52ce04-874b-4d27-9045-b3bc80def685", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "metadata": {"usage": {"prompt_tokens": 1033, "prompt_unit_price": "0.001", "prompt_price_unit": "0.001", "prompt_price": "0.0010330", "completion_tokens": 135, "completion_unit_price": "0.002", "completion_price_unit": "0.001", "completion_price": "0.0002700", "total_tokens": 1168, "total_price": "0.0013030", "currency": "USD", "latency": 1.381760165997548}, "retriever_resources": [{"position": 1, "dataset_id": "101b4c97-fc2e-463c-90b1-5261a4cdcafb", "dataset_name": "iPhone", "document_id": "8dd1ad74-0b5f-4175-b735-7d98bbbb4e00", "document_name": "iPhone List", "segment_id": "ed599c7f-2766-4294-9d1d-e5235a61270a", "score": 0.98457545, "content": "\"Model\",\"Release Date\",\"Display Size\",\"Resolution\",\"Processor\",\"RAM\",\"Storage\",\"Camera\",\"Battery\",\"Operating System\"\n\"iPhone 13 Pro Max\",\"September 24, 2021\",\"6.7 inch\",\"1284 x 2778\",\"Hexa-core (2x3.23 GHz Avalanche + 4x1.82 GHz Blizzard)\",\"6 GB\",\"128, 256, 512 GB, 1TB\",\"12 MP\",\"4352 mAh\",\"iOS 15\""}]}} data: {"event": "tts_message", "conversation_id": "23dd85f3-1a41-4ea0-b7a9-062734ccfaf9", "message_id": "a8bdc41c-13b2-4c18-bfd9-054b9803038c", "created_at": 1721205487, "task_id": "3bf8a0bb-e73b-4690-9e66-4e429bad8ee7", "audio": "qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq"} data: {"event": "tts_message_end", "conversation_id": "23dd85f3-1a41-4ea0-b7a9-062734ccfaf9", "message_id": "a8bdc41c-13b2-4c18-bfd9-054b9803038c", "created_at": 1721205487, "task_id": "3bf8a0bb-e73b-4690-9e66-4e429bad8ee7", "audio": ""} ``` @@ -256,7 +256,7 @@ Chat applications support session persistence, allowing previous chat history to data: {"event": "message", "message_id": "5ad4cb98-f0c7-4085-b384-88c403be6290", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "answer": " to", "created_at": 1679586595} data: {"event": "message", "message_id": : "5ad4cb98-f0c7-4085-b384-88c403be6290", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "answer": " meet", "created_at": 1679586595} data: {"event": "message", "message_id": : "5ad4cb98-f0c7-4085-b384-88c403be6290", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "answer": " you", "created_at": 1679586595} - data: {"event": "message_end", "id": "5e52ce04-874b-4d27-9045-b3bc80def685", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "metadata": {"usage": {"prompt_tokens": 1033, "prompt_unit_price": "0.001", "prompt_price_unit": "0.001", "prompt_price": "0.0010330", "completion_tokens": 135, "completion_unit_price": "0.002", "completion_price_unit": "0.001", "completion_price": "0.0002700", "total_tokens": 1168, "total_price": "0.0013030", "currency": "USD", "latency": 1.381760165997548, "retriever_resources": [{"position": 1, "dataset_id": "101b4c97-fc2e-463c-90b1-5261a4cdcafb", "dataset_name": "iPhone", "document_id": "8dd1ad74-0b5f-4175-b735-7d98bbbb4e00", "document_name": "iPhone List", "segment_id": "ed599c7f-2766-4294-9d1d-e5235a61270a", "score": 0.98457545, "content": "\"Model\",\"Release Date\",\"Display Size\",\"Resolution\",\"Processor\",\"RAM\",\"Storage\",\"Camera\",\"Battery\",\"Operating System\"\n\"iPhone 13 Pro Max\",\"September 24, 2021\",\"6.7 inch\",\"1284 x 2778\",\"Hexa-core (2x3.23 GHz Avalanche + 4x1.82 GHz Blizzard)\",\"6 GB\",\"128, 256, 512 GB, 1TB\",\"12 MP\",\"4352 mAh\",\"iOS 15\""}]}}} + data: {"event": "message_end", "id": "5e52ce04-874b-4d27-9045-b3bc80def685", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "metadata": {"usage": {"prompt_tokens": 1033, "prompt_unit_price": "0.001", "prompt_price_unit": "0.001", "prompt_price": "0.0010330", "completion_tokens": 135, "completion_unit_price": "0.002", "completion_price_unit": "0.001", "completion_price": "0.0002700", "total_tokens": 1168, "total_price": "0.0013030", "currency": "USD", "latency": 1.381760165997548}, "retriever_resources": [{"position": 1, "dataset_id": "101b4c97-fc2e-463c-90b1-5261a4cdcafb", "dataset_name": "iPhone", "document_id": "8dd1ad74-0b5f-4175-b735-7d98bbbb4e00", "document_name": "iPhone List", "segment_id": "ed599c7f-2766-4294-9d1d-e5235a61270a", "score": 0.98457545, "content": "\"Model\",\"Release Date\",\"Display Size\",\"Resolution\",\"Processor\",\"RAM\",\"Storage\",\"Camera\",\"Battery\",\"Operating System\"\n\"iPhone 13 Pro Max\",\"September 24, 2021\",\"6.7 inch\",\"1284 x 2778\",\"Hexa-core (2x3.23 GHz Avalanche + 4x1.82 GHz Blizzard)\",\"6 GB\",\"128, 256, 512 GB, 1TB\",\"12 MP\",\"4352 mAh\",\"iOS 15\""}]}} data: {"event": "tts_message", "conversation_id": "23dd85f3-1a41-4ea0-b7a9-062734ccfaf9", "message_id": "a8bdc41c-13b2-4c18-bfd9-054b9803038c", "created_at": 1721205487, "task_id": "3bf8a0bb-e73b-4690-9e66-4e429bad8ee7", "audio": "qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq"} data: {"event": "tts_message_end", "conversation_id": "23dd85f3-1a41-4ea0-b7a9-062734ccfaf9", "message_id": "a8bdc41c-13b2-4c18-bfd9-054b9803038c", "created_at": 1721205487, "task_id": "3bf8a0bb-e73b-4690-9e66-4e429bad8ee7", "audio": ""} ``` diff --git a/web/app/components/develop/template/template_chat.zh.mdx b/web/app/components/develop/template/template_chat.zh.mdx index 9323313d9f..727d884c1a 100644 --- a/web/app/components/develop/template/template_chat.zh.mdx +++ b/web/app/components/develop/template/template_chat.zh.mdx @@ -255,7 +255,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' data: {"event": "message", "message_id": "5ad4cb98-f0c7-4085-b384-88c403be6290", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "answer": " to", "created_at": 1679586595} data: {"event": "message", "message_id": : "5ad4cb98-f0c7-4085-b384-88c403be6290", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "answer": " meet", "created_at": 1679586595} data: {"event": "message", "message_id": : "5ad4cb98-f0c7-4085-b384-88c403be6290", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "answer": " you", "created_at": 1679586595} - data: {"event": "message_end", "id": "5e52ce04-874b-4d27-9045-b3bc80def685", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "metadata": {"usage": {"prompt_tokens": 1033, "prompt_unit_price": "0.001", "prompt_price_unit": "0.001", "prompt_price": "0.0010330", "completion_tokens": 135, "completion_unit_price": "0.002", "completion_price_unit": "0.001", "completion_price": "0.0002700", "total_tokens": 1168, "total_price": "0.0013030", "currency": "USD", "latency": 1.381760165997548, "retriever_resources": [{"position": 1, "dataset_id": "101b4c97-fc2e-463c-90b1-5261a4cdcafb", "dataset_name": "iPhone", "document_id": "8dd1ad74-0b5f-4175-b735-7d98bbbb4e00", "document_name": "iPhone List", "segment_id": "ed599c7f-2766-4294-9d1d-e5235a61270a", "score": 0.98457545, "content": "\"Model\",\"Release Date\",\"Display Size\",\"Resolution\",\"Processor\",\"RAM\",\"Storage\",\"Camera\",\"Battery\",\"Operating System\"\n\"iPhone 13 Pro Max\",\"September 24, 2021\",\"6.7 inch\",\"1284 x 2778\",\"Hexa-core (2x3.23 GHz Avalanche + 4x1.82 GHz Blizzard)\",\"6 GB\",\"128, 256, 512 GB, 1TB\",\"12 MP\",\"4352 mAh\",\"iOS 15\""}]}}} + data: {"event": "message_end", "id": "5e52ce04-874b-4d27-9045-b3bc80def685", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "metadata": {"usage": {"prompt_tokens": 1033, "prompt_unit_price": "0.001", "prompt_price_unit": "0.001", "prompt_price": "0.0010330", "completion_tokens": 135, "completion_unit_price": "0.002", "completion_price_unit": "0.001", "completion_price": "0.0002700", "total_tokens": 1168, "total_price": "0.0013030", "currency": "USD", "latency": 1.381760165997548}, "retriever_resources": [{"position": 1, "dataset_id": "101b4c97-fc2e-463c-90b1-5261a4cdcafb", "dataset_name": "iPhone", "document_id": "8dd1ad74-0b5f-4175-b735-7d98bbbb4e00", "document_name": "iPhone List", "segment_id": "ed599c7f-2766-4294-9d1d-e5235a61270a", "score": 0.98457545, "content": "\"Model\",\"Release Date\",\"Display Size\",\"Resolution\",\"Processor\",\"RAM\",\"Storage\",\"Camera\",\"Battery\",\"Operating System\"\n\"iPhone 13 Pro Max\",\"September 24, 2021\",\"6.7 inch\",\"1284 x 2778\",\"Hexa-core (2x3.23 GHz Avalanche + 4x1.82 GHz Blizzard)\",\"6 GB\",\"128, 256, 512 GB, 1TB\",\"12 MP\",\"4352 mAh\",\"iOS 15\""}]}} data: {"event": "tts_message", "conversation_id": "23dd85f3-1a41-4ea0-b7a9-062734ccfaf9", "message_id": "a8bdc41c-13b2-4c18-bfd9-054b9803038c", "created_at": 1721205487, "task_id": "3bf8a0bb-e73b-4690-9e66-4e429bad8ee7", "audio": "qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq"} data: {"event": "tts_message_end", "conversation_id": "23dd85f3-1a41-4ea0-b7a9-062734ccfaf9", "message_id": "a8bdc41c-13b2-4c18-bfd9-054b9803038c", "created_at": 1721205487, "task_id": "3bf8a0bb-e73b-4690-9e66-4e429bad8ee7", "audio": ""} ``` @@ -274,7 +274,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' data: {"event": "agent_message", "id": "1fb10045-55fd-4040-99e6-d048d07cbad3", "task_id": "9cf1ddd7-f94b-459b-b942-b77b26c59e9b", "message_id": "1fb10045-55fd-4040-99e6-d048d07cbad3", "answer": " eyes wearing a bunny girl" ,"created_at": 1705639511, "conversation_id": "c216c595-2d89-438c-b33c-aae5ddddd142"} data: {"event": "agent_message", "id": "1fb10045-55fd-4040-99e6-d048d07cbad3", "task_id": "9cf1ddd7-f94b-459b-b942-b77b26c59e9b", "message_id": "1fb10045-55fd-4040-99e6-d048d07cbad3", "answer": " suit .", "created_at": 1705639511, "conversation_id": "c216c595-2d89-438c-b33c-aae5ddddd142"} data: {"event": "agent_thought", "id": "67a99dc1-4f82-42d3-b354-18d4594840c8", "task_id": "9cf1ddd7-f94b-459b-b942-b77b26c59e9b", "message_id": "1fb10045-55fd-4040-99e6-d048d07cbad3", "position": 2, "thought": "I have created an image of a cute Japanese anime girl with white hair and blue eyes wearing a bunny girl suit.", "observation": "", "tool": "", "tool_input": "", "created_at": 1705639511, "message_files": [], "conversation_id": "c216c595-2d89-438c-b33c-aae5ddddd142"} - data: {"event": "message_end", "task_id": "9cf1ddd7-f94b-459b-b942-b77b26c59e9b", "id": "1fb10045-55fd-4040-99e6-d048d07cbad3", "message_id": "1fb10045-55fd-4040-99e6-d048d07cbad3", "conversation_id": "c216c595-2d89-438c-b33c-aae5ddddd142", "metadata": {"usage": {"prompt_tokens": 305, "prompt_unit_price": "0.001", "prompt_price_unit": "0.001", "prompt_price": "0.0003050", "completion_tokens": 97, "completion_unit_price": "0.002", "completion_price_unit": "0.001", "completion_price": "0.0001940", "total_tokens": 184, "total_price": "0.0002290", "currency": "USD", "latency": 1.771092874929309}}} + data: {"event": "message_end", "id": "5e52ce04-874b-4d27-9045-b3bc80def685", "conversation_id": "45701982-8118-4bc5-8e9b-64562b4555f2", "metadata": {"usage": {"prompt_tokens": 1033, "prompt_unit_price": "0.001", "prompt_price_unit": "0.001", "prompt_price": "0.0010330", "completion_tokens": 135, "completion_unit_price": "0.002", "completion_price_unit": "0.001", "completion_price": "0.0002700", "total_tokens": 1168, "total_price": "0.0013030", "currency": "USD", "latency": 1.381760165997548}, "retriever_resources": [{"position": 1, "dataset_id": "101b4c97-fc2e-463c-90b1-5261a4cdcafb", "dataset_name": "iPhone", "document_id": "8dd1ad74-0b5f-4175-b735-7d98bbbb4e00", "document_name": "iPhone List", "segment_id": "ed599c7f-2766-4294-9d1d-e5235a61270a", "score": 0.98457545, "content": "\"Model\",\"Release Date\",\"Display Size\",\"Resolution\",\"Processor\",\"RAM\",\"Storage\",\"Camera\",\"Battery\",\"Operating System\"\n\"iPhone 13 Pro Max\",\"September 24, 2021\",\"6.7 inch\",\"1284 x 2778\",\"Hexa-core (2x3.23 GHz Avalanche + 4x1.82 GHz Blizzard)\",\"6 GB\",\"128, 256, 512 GB, 1TB\",\"12 MP\",\"4352 mAh\",\"iOS 15\""}]}} data: {"event": "tts_message", "conversation_id": "23dd85f3-1a41-4ea0-b7a9-062734ccfaf9", "message_id": "a8bdc41c-13b2-4c18-bfd9-054b9803038c", "created_at": 1721205487, "task_id": "3bf8a0bb-e73b-4690-9e66-4e429bad8ee7", "audio": "qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq"} data: {"event": "tts_message_end", "conversation_id": "23dd85f3-1a41-4ea0-b7a9-062734ccfaf9", "message_id": "a8bdc41c-13b2-4c18-bfd9-054b9803038c", "created_at": 1721205487, "task_id": "3bf8a0bb-e73b-4690-9e66-4e429bad8ee7", "audio": ""} ``` From a40073afa446b4bb6bcd90477177baa6829949f1 Mon Sep 17 00:00:00 2001 From: Hanqing Zhao Date: Fri, 23 Aug 2024 08:34:22 +0800 Subject: [PATCH 03/24] Add N-to-1 warning translation for JP (#7553) --- web/i18n/ja-JP/dataset.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/web/i18n/ja-JP/dataset.ts b/web/i18n/ja-JP/dataset.ts index 2cdc819cd5..5df8aaea7f 100644 --- a/web/i18n/ja-JP/dataset.ts +++ b/web/i18n/ja-JP/dataset.ts @@ -67,7 +67,9 @@ const translation = { semantic: 'セマンティクス', keyword: 'キーワード', }, - nTo1RetrievalLegacy: '製品計画によると、N To 1 Retrievalは9月に正式に廃止される予定です。それまでは通常通り使用できます。', + nTo1RetrievalLegacy: '製品計画によると、N-to-1 Retrievalは9月に正式に廃止される予定です。それまでは通常通り使用できます。', + nTo1RetrievalLegacyLink: '詳細を見る', + nTo1RetrievalLegacyLinkText: ' N-to-1 retrievalは9月に正式に廃止されます。', } export default translation From a24717765e5dd63519549a893230ae24fe6a4269 Mon Sep 17 00:00:00 2001 From: orangeclk Date: Fri, 23 Aug 2024 11:15:38 +0800 Subject: [PATCH 04/24] feat: forward zhipu finish_reason (#7560) --- api/core/model_runtime/model_providers/zhipuai/llm/llm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index ff971964a8..13d8f5e5c3 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -444,6 +444,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, + finish_reason=delta.finish_reason ) ) From 3ac8a2871e2ea88d46f90ccd4fdc3f1d71c527af Mon Sep 17 00:00:00 2001 From: Nam Vu Date: Fri, 23 Aug 2024 10:16:37 +0700 Subject: [PATCH 05/24] chore: #6554 i18n (#7562) --- web/i18n/de-DE/dataset.ts | 26 ++++++++++++++++++++++++++ web/i18n/es-ES/dataset.ts | 26 ++++++++++++++++++++++++++ web/i18n/fr-FR/dataset.ts | 26 ++++++++++++++++++++++++++ web/i18n/hi-IN/dataset.ts | 26 ++++++++++++++++++++++++++ web/i18n/it-IT/dataset.ts | 26 ++++++++++++++++++++++++++ web/i18n/ko-KR/dataset.ts | 26 ++++++++++++++++++++++++++ web/i18n/pl-PL/dataset.ts | 26 ++++++++++++++++++++++++++ web/i18n/pt-BR/dataset.ts | 26 ++++++++++++++++++++++++++ web/i18n/ro-RO/dataset.ts | 26 ++++++++++++++++++++++++++ web/i18n/tr-TR/dataset.ts | 2 ++ web/i18n/uk-UA/dataset.ts | 26 ++++++++++++++++++++++++++ web/i18n/vi-VN/dataset.ts | 26 ++++++++++++++++++++++++++ web/i18n/zh-Hant/dataset.ts | 26 ++++++++++++++++++++++++++ 13 files changed, 314 insertions(+) diff --git a/web/i18n/de-DE/dataset.ts b/web/i18n/de-DE/dataset.ts index 53e8cdd447..c6586ceee8 100644 --- a/web/i18n/de-DE/dataset.ts +++ b/web/i18n/de-DE/dataset.ts @@ -45,6 +45,32 @@ const translation = { }, docsFailedNotice: 'Dokumente konnten nicht indiziert werden', retry: 'Wiederholen', + indexingTechnique: { + high_quality: 'HQ', + economy: 'ECO', + }, + indexingMethod: { + semantic_search: 'VEKTOR', + full_text_search: 'VOLLTEXT', + hybrid_search: 'HYBRID', + invertedIndex: 'INVERTIERT', + }, + mixtureHighQualityAndEconomicTip: 'Für die Mischung von hochwertigen und wirtschaftlichen Wissensbasen ist das Rerank-Modell erforderlich.', + inconsistentEmbeddingModelTip: 'Das Rerank-Modell ist erforderlich, wenn die Embedding-Modelle der ausgewählten Wissensbasen inkonsistent sind.', + retrievalSettings: 'Abrufeinstellungen', + rerankSettings: 'Rerank-Einstellungen', + weightedScore: { + title: 'Gewichtete Bewertung', + description: 'Durch Anpassung der zugewiesenen Gewichte bestimmt diese Rerank-Strategie, ob semantische oder Schlüsselwort-Übereinstimmung priorisiert werden soll.', + semanticFirst: 'Semantik zuerst', + keywordFirst: 'Schlüsselwort zuerst', + customized: 'Angepasst', + semantic: 'Semantisch', + keyword: 'Schlüsselwort', + }, + nTo1RetrievalLegacy: 'N-zu-1-Abruf wird ab September offiziell eingestellt. Es wird empfohlen, den neuesten Multi-Pfad-Abruf zu verwenden, um bessere Ergebnisse zu erzielen.', + nTo1RetrievalLegacyLink: 'Mehr erfahren', + nTo1RetrievalLegacyLinkText: 'N-zu-1-Abruf wird im September offiziell eingestellt.', } export default translation diff --git a/web/i18n/es-ES/dataset.ts b/web/i18n/es-ES/dataset.ts index 307187b605..e4fc362efa 100644 --- a/web/i18n/es-ES/dataset.ts +++ b/web/i18n/es-ES/dataset.ts @@ -45,6 +45,32 @@ const translation = { }, docsFailedNotice: 'no se pudieron indexar los documentos', retry: 'Reintentar', + indexingTechnique: { + high_quality: 'AC', + economy: 'ECO', + }, + indexingMethod: { + semantic_search: 'VECTOR', + full_text_search: 'TEXTO COMPLETO', + hybrid_search: 'HÍBRIDO', + invertedIndex: 'INVERTIDO', + }, + mixtureHighQualityAndEconomicTip: 'Se requiere el modelo de reclasificación para la mezcla de bases de conocimiento de alta calidad y económicas.', + inconsistentEmbeddingModelTip: 'Se requiere el modelo de reclasificación si los modelos de incrustación de las bases de conocimiento seleccionadas son inconsistentes.', + retrievalSettings: 'Configuración de recuperación', + rerankSettings: 'Configuración de reclasificación', + weightedScore: { + title: 'Puntuación ponderada', + description: 'Al ajustar los pesos asignados, esta estrategia de reclasificación determina si se debe priorizar la coincidencia semántica o de palabras clave.', + semanticFirst: 'Semántica primero', + keywordFirst: 'Palabra clave primero', + customized: 'Personalizado', + semantic: 'Semántico', + keyword: 'Palabra clave', + }, + nTo1RetrievalLegacy: 'La recuperación N-a-1 será oficialmente obsoleta a partir de septiembre. Se recomienda utilizar la última recuperación de múltiples rutas para obtener mejores resultados.', + nTo1RetrievalLegacyLink: 'Más información', + nTo1RetrievalLegacyLinkText: 'La recuperación N-a-1 será oficialmente obsoleta en septiembre.', } export default translation diff --git a/web/i18n/fr-FR/dataset.ts b/web/i18n/fr-FR/dataset.ts index 2ba5819c24..014168e006 100644 --- a/web/i18n/fr-FR/dataset.ts +++ b/web/i18n/fr-FR/dataset.ts @@ -45,6 +45,32 @@ const translation = { }, docsFailedNotice: 'Les documents n\'ont pas pu être indexés', retry: 'Réessayer', + indexingTechnique: { + high_quality: 'HQ', + economy: 'ÉCO', + }, + indexingMethod: { + semantic_search: 'VECTEUR', + full_text_search: 'TEXTE INTÉGRAL', + hybrid_search: 'HYBRIDE', + invertedIndex: 'INVERSÉ', + }, + mixtureHighQualityAndEconomicTip: 'Le modèle de reclassement est nécessaire pour le mélange de bases de connaissances de haute qualité et économiques.', + inconsistentEmbeddingModelTip: 'Le modèle de reclassement est nécessaire si les modèles d\'incorporation des bases de connaissances sélectionnées sont incohérents.', + retrievalSettings: 'Paramètres de récupération', + rerankSettings: 'Paramètres de reclassement', + weightedScore: { + title: 'Score pondéré', + description: 'En ajustant les poids attribués, cette stratégie de reclassement détermine s\'il faut prioriser la correspondance sémantique ou par mots-clés.', + semanticFirst: 'Sémantique d\'abord', + keywordFirst: 'Mot-clé d\'abord', + customized: 'Personnalisé', + semantic: 'Sémantique', + keyword: 'Mot-clé', + }, + nTo1RetrievalLegacy: 'La récupération N-à-1 sera officiellement obsolète à partir de septembre. Il est recommandé d\'utiliser la dernière récupération multi-chemins pour obtenir de meilleurs résultats.', + nTo1RetrievalLegacyLink: 'En savoir plus', + nTo1RetrievalLegacyLinkText: 'La récupération N-à-1 sera officiellement obsolète en septembre.', } export default translation diff --git a/web/i18n/hi-IN/dataset.ts b/web/i18n/hi-IN/dataset.ts index 859887a54d..de33113d2b 100644 --- a/web/i18n/hi-IN/dataset.ts +++ b/web/i18n/hi-IN/dataset.ts @@ -52,6 +52,32 @@ const translation = { }, docsFailedNotice: 'दस्तावेज़ों को अनुक्रमित करने में विफल', retry: 'पुनः प्रयास करें', + indexingTechnique: { + high_quality: 'उच्च गुणवत्ता', + economy: 'किफायती', + }, + indexingMethod: { + semantic_search: 'वेक्टर', + full_text_search: 'पूर्ण पाठ', + hybrid_search: 'हाइब्रिड', + invertedIndex: 'उल्टा', + }, + mixtureHighQualityAndEconomicTip: 'उच्च गुणवत्ता और किफायती ज्ञान आधारों के मिश्रण के लिए पुनः रैंकिंग मॉडल आवश्यक है।', + inconsistentEmbeddingModelTip: 'यदि चयनित ज्ञान आधारों के एम्बेडिंग मॉडल असंगत हैं, तो पुनः रैंकिंग मॉडल आवश्यक है।', + retrievalSettings: 'पुनर्प्राप्ति सेटिंग्स', + rerankSettings: 'पुनः रैंकिंग सेटिंग्स', + weightedScore: { + title: 'भारित स्कोर', + description: 'आवंटित भारों को समायोजित करके, यह पुनः रैंकिंग रणनीति निर्धारित करती है कि शब्दार्थ या कीवर्ड मिलान को प्राथमिकता दी जाए।', + semanticFirst: 'शब्दार्थ पहले', + keywordFirst: 'कीवर्ड पहले', + customized: 'अनुकूलित', + semantic: 'शब्दार्थ', + keyword: 'कीवर्ड', + }, + nTo1RetrievalLegacy: 'N-से-1 पुनर्प्राप्ति सितंबर से आधिकारिक तौर पर बंद कर दी जाएगी। बेहतर परिणाम प्राप्त करने के लिए नवीनतम बहु-मार्ग पुनर्प्राप्ति का उपयोग करने की सिफारिश की जाती है।', + nTo1RetrievalLegacyLink: 'और जानें', + nTo1RetrievalLegacyLinkText: 'N-से-1 पुनर्प्राप्ति सितंबर में आधिकारिक तौर पर बंद कर दी जाएगी।', } export default translation diff --git a/web/i18n/it-IT/dataset.ts b/web/i18n/it-IT/dataset.ts index f191f6f2a6..9223a3a96d 100644 --- a/web/i18n/it-IT/dataset.ts +++ b/web/i18n/it-IT/dataset.ts @@ -52,6 +52,32 @@ const translation = { }, docsFailedNotice: 'documenti non riusciti a essere indicizzati', retry: 'Riprova', + indexingTechnique: { + high_quality: 'AQ', + economy: 'ECO', + }, + indexingMethod: { + semantic_search: 'VETTORE', + full_text_search: 'TESTO COMPLETO', + hybrid_search: 'IBRIDO', + invertedIndex: 'INVERTITO', + }, + mixtureHighQualityAndEconomicTip: 'Il modello di riclassificazione è necessario per la miscela di basi di conoscenza di alta qualità ed economiche.', + inconsistentEmbeddingModelTip: 'Il modello di riclassificazione è necessario se i modelli di embedding delle basi di conoscenza selezionate sono incoerenti.', + retrievalSettings: 'Impostazioni di recupero', + rerankSettings: 'Impostazioni di riclassificazione', + weightedScore: { + title: 'Punteggio ponderato', + description: 'Regolando i pesi assegnati, questa strategia di riclassificazione determina se dare priorità alla corrispondenza semantica o per parole chiave.', + semanticFirst: 'Semantica prima', + keywordFirst: 'Parola chiave prima', + customized: 'Personalizzato', + semantic: 'Semantico', + keyword: 'Parola chiave', + }, + nTo1RetrievalLegacy: 'Il recupero N-a-1 sarà ufficialmente deprecato da settembre. Si consiglia di utilizzare il più recente recupero multi-percorso per ottenere risultati migliori.', + nTo1RetrievalLegacyLink: 'Scopri di più', + nTo1RetrievalLegacyLinkText: 'Il recupero N-a-1 sarà ufficialmente deprecato a settembre.', } export default translation diff --git a/web/i18n/ko-KR/dataset.ts b/web/i18n/ko-KR/dataset.ts index 27a5d7320e..907a1f21b6 100644 --- a/web/i18n/ko-KR/dataset.ts +++ b/web/i18n/ko-KR/dataset.ts @@ -44,6 +44,32 @@ const translation = { }, docsFailedNotice: '문서 인덱스에 실패했습니다', retry: '재시도', + indexingTechnique: { + high_quality: 'HQ', + economy: '이코노미', + }, + indexingMethod: { + semantic_search: '벡터', + full_text_search: '전체 텍스트', + hybrid_search: '하이브리드', + invertedIndex: '역인덱스', + }, + mixtureHighQualityAndEconomicTip: '고품질과 경제적 지식 베이스의 혼합을 위해서는 재순위 모델이 필요합니다.', + inconsistentEmbeddingModelTip: '선택된 지식 베이스의 임베딩 모델이 일관되지 않은 경우 재순위 모델이 필요합니다.', + retrievalSettings: '검색 설정', + rerankSettings: '재순위 설정', + weightedScore: { + title: '가중 점수', + description: '할당된 가중치를 조정함으로써, 이 재순위 전략은 의미론적 일치 또는 키워드 일치 중 어느 것을 우선시할지 결정합니다.', + semanticFirst: '의미론 우선', + keywordFirst: '키워드 우선', + customized: '사용자 정의', + semantic: '의미론적', + keyword: '키워드', + }, + nTo1RetrievalLegacy: 'N-대-1 검색은 9월부터 공식적으로 더 이상 사용되지 않습니다. 더 나은 결과를 얻으려면 최신 다중 경로 검색을 사용하는 것이 좋습니다.', + nTo1RetrievalLegacyLink: '자세히 알아보기', + nTo1RetrievalLegacyLinkText: 'N-대-1 검색은 9월에 공식적으로 더 이상 사용되지 않습니다.', } export default translation diff --git a/web/i18n/pl-PL/dataset.ts b/web/i18n/pl-PL/dataset.ts index 5351b3c739..14de4eaf40 100644 --- a/web/i18n/pl-PL/dataset.ts +++ b/web/i18n/pl-PL/dataset.ts @@ -51,6 +51,32 @@ const translation = { }, docsFailedNotice: 'nie udało się zindeksować dokumentów', retry: 'Ponów', + indexingTechnique: { + high_quality: 'WJ', + economy: 'EKO', + }, + indexingMethod: { + semantic_search: 'WEKTOR', + full_text_search: 'PEŁNY TEKST', + hybrid_search: 'HYBRYDOWY', + invertedIndex: 'ODWRÓCONY', + }, + mixtureHighQualityAndEconomicTip: 'Model ponownego rankingu jest wymagany dla mieszanki wysokiej jakości i ekonomicznych baz wiedzy.', + inconsistentEmbeddingModelTip: 'Model ponownego rankingu jest wymagany, jeśli modele osadzania wybranych baz wiedzy są niespójne.', + retrievalSettings: 'Ustawienia wyszukiwania', + rerankSettings: 'Ustawienia ponownego rankingu', + weightedScore: { + title: 'Ważona ocena', + description: 'Poprzez dostosowanie przypisanych wag, ta strategia ponownego rankingu określa, czy priorytetowo traktować dopasowanie semantyczne czy słów kluczowych.', + semanticFirst: 'Najpierw semantyczne', + keywordFirst: 'Najpierw słowa kluczowe', + customized: 'Dostosowane', + semantic: 'Semantyczne', + keyword: 'Słowo kluczowe', + }, + nTo1RetrievalLegacy: 'Wyszukiwanie N-do-1 zostanie oficjalnie wycofane od września. Zaleca się korzystanie z najnowszego wyszukiwania wielościeżkowego, aby uzyskać lepsze wyniki.', + nTo1RetrievalLegacyLink: 'Dowiedz się więcej', + nTo1RetrievalLegacyLinkText: 'Wyszukiwanie N-do-1 zostanie oficjalnie wycofane we wrześniu.', } export default translation diff --git a/web/i18n/pt-BR/dataset.ts b/web/i18n/pt-BR/dataset.ts index f6f76e4626..8710879149 100644 --- a/web/i18n/pt-BR/dataset.ts +++ b/web/i18n/pt-BR/dataset.ts @@ -44,6 +44,32 @@ const translation = { }, docsFailedNotice: 'documentos falharam ao serem indexados', retry: 'Tentar novamente', + indexingTechnique: { + high_quality: 'AQ', + economy: 'ECO', + }, + indexingMethod: { + semantic_search: 'VETOR', + full_text_search: 'TEXTO COMPLETO', + hybrid_search: 'HÍBRIDO', + invertedIndex: 'INVERTIDO', + }, + mixtureHighQualityAndEconomicTip: 'O modelo de reclassificação é necessário para a mistura de bases de conhecimento de alta qualidade e econômicas.', + inconsistentEmbeddingModelTip: 'O modelo de reclassificação é necessário se os modelos de incorporação das bases de conhecimento selecionadas forem inconsistentes.', + retrievalSettings: 'Configurações de Recuperação', + rerankSettings: 'Configurações de Reclassificação', + weightedScore: { + title: 'Pontuação Ponderada', + description: 'Ao ajustar os pesos atribuídos, esta estratégia de reclassificação determina se deve priorizar a correspondência semântica ou por palavras-chave.', + semanticFirst: 'Semântica primeiro', + keywordFirst: 'Palavra-chave primeiro', + customized: 'Personalizado', + semantic: 'Semântico', + keyword: 'Palavra-chave', + }, + nTo1RetrievalLegacy: 'A recuperação N-para-1 será oficialmente descontinuada a partir de setembro. Recomenda-se usar a recuperação de múltiplos caminhos mais recente para obter melhores resultados.', + nTo1RetrievalLegacyLink: 'Saiba mais', + nTo1RetrievalLegacyLinkText: 'A recuperação N-para-1 será oficialmente descontinuada em setembro.', } export default translation diff --git a/web/i18n/ro-RO/dataset.ts b/web/i18n/ro-RO/dataset.ts index 363d882b09..3e605baf92 100644 --- a/web/i18n/ro-RO/dataset.ts +++ b/web/i18n/ro-RO/dataset.ts @@ -45,6 +45,32 @@ const translation = { }, docsFailedNotice: 'documentele nu au putut fi indexate', retry: 'Reîncercați', + indexingTechnique: { + high_quality: 'IC', + economy: 'ECO', + }, + indexingMethod: { + semantic_search: 'VECTOR', + full_text_search: 'TEXT COMPLET', + hybrid_search: 'HIBRID', + invertedIndex: 'INVERSAT', + }, + mixtureHighQualityAndEconomicTip: 'Modelul de reclasificare este necesar pentru amestecul de baze de cunoștințe de înaltă calitate și economice.', + inconsistentEmbeddingModelTip: 'Modelul de reclasificare este necesar dacă modelele de încorporare ale bazelor de cunoștințe selectate sunt inconsistente.', + retrievalSettings: 'Setări de recuperare', + rerankSettings: 'Setări de reclasificare', + weightedScore: { + title: 'Scor ponderat', + description: 'Prin ajustarea ponderilor atribuite, această strategie de reclasificare determină dacă să prioritizeze potrivirea semantică sau pe cea a cuvintelor cheie.', + semanticFirst: 'Semantic primul', + keywordFirst: 'Cuvânt cheie primul', + customized: 'Personalizat', + semantic: 'Semantic', + keyword: 'Cuvânt cheie', + }, + nTo1RetrievalLegacy: 'Recuperarea N-la-1 va fi oficial depreciată din septembrie. Se recomandă utilizarea celei mai recente recuperări cu căi multiple pentru a obține rezultate mai bune.', + nTo1RetrievalLegacyLink: 'Află mai multe', + nTo1RetrievalLegacyLinkText: 'Recuperarea N-la-1 va fi oficial depreciată în septembrie.', } export default translation diff --git a/web/i18n/tr-TR/dataset.ts b/web/i18n/tr-TR/dataset.ts index 68bd5860bb..31d483f504 100644 --- a/web/i18n/tr-TR/dataset.ts +++ b/web/i18n/tr-TR/dataset.ts @@ -68,6 +68,8 @@ const translation = { keyword: 'Anahtar Kelime', }, nTo1RetrievalLegacy: 'Geri alım stratejisinin optimizasyonu ve yükseltilmesi nedeniyle, N-to-1 geri alımı Eylül ayında resmi olarak kullanım dışı kalacaktır. O zamana kadar normal şekilde kullanabilirsiniz.', + nTo1RetrievalLegacyLink: 'Daha fazla bilgi edin', + nTo1RetrievalLegacyLinkText: 'N-1 geri alma Eylül ayında resmi olarak kullanımdan kaldırılacaktır.', } export default translation diff --git a/web/i18n/uk-UA/dataset.ts b/web/i18n/uk-UA/dataset.ts index fb44b4107a..3bf59ed33b 100644 --- a/web/i18n/uk-UA/dataset.ts +++ b/web/i18n/uk-UA/dataset.ts @@ -45,6 +45,32 @@ const translation = { }, docsFailedNotice: 'документи не вдалося проіндексувати', retry: 'Повторити спробу', + indexingTechnique: { + high_quality: 'ВЯ', + economy: 'ЕКО', + }, + indexingMethod: { + semantic_search: 'ВЕКТОР', + full_text_search: 'ПОВНИЙ ТЕКСТ', + hybrid_search: 'ГІБРИД', + invertedIndex: 'ІНВЕРТОВАНИЙ', + }, + mixtureHighQualityAndEconomicTip: 'Модель перерангування потрібна для суміші високоякісних та економічних баз знань.', + inconsistentEmbeddingModelTip: 'Модель перерангування потрібна, якщо моделі вбудовування вибраних баз знань є несумісними.', + retrievalSettings: 'Налаштування пошуку', + rerankSettings: 'Налаштування перерангування', + weightedScore: { + title: 'Зважена оцінка', + description: 'Регулюючи призначені ваги, ця стратегія перерангування визначає, чи надавати пріоритет семантичному чи ключовому відповідності.', + semanticFirst: 'Спочатку семантичний', + keywordFirst: 'Спочатку ключове слово', + customized: 'Налаштований', + semantic: 'Семантичний', + keyword: 'Ключове слово', + }, + nTo1RetrievalLegacy: 'N-до-1 пошук буде офіційно застарілим з вересня. Рекомендується використовувати найновіший багатошляховий пошук для отримання кращих результатів.', + nTo1RetrievalLegacyLink: 'Дізнатися більше', + nTo1RetrievalLegacyLinkText: 'N-до-1 пошук буде офіційно застарілим у вересні.', } export default translation diff --git a/web/i18n/vi-VN/dataset.ts b/web/i18n/vi-VN/dataset.ts index 27dad01f51..81b4597800 100644 --- a/web/i18n/vi-VN/dataset.ts +++ b/web/i18n/vi-VN/dataset.ts @@ -45,6 +45,32 @@ const translation = { }, docsFailedNotice: 'tài liệu không được lập chỉ mục', retry: 'Thử lại', + indexingTechnique: { + high_quality: 'CHẤT LƯỢNG', + economy: 'TIẾT KIỆM', + }, + indexingMethod: { + semantic_search: 'VECTOR', + full_text_search: 'VĂN BẢN ĐẦY ĐỦ', + hybrid_search: 'KẾT HỢP', + invertedIndex: 'ĐẢO NGƯỢC', + }, + mixtureHighQualityAndEconomicTip: 'Mô hình xếp hạng lại là cần thiết cho sự kết hợp của các cơ sở kiến thức chất lượng cao và tiết kiệm.', + inconsistentEmbeddingModelTip: 'Mô hình xếp hạng lại là cần thiết nếu các mô hình nhúng của các cơ sở kiến thức được chọn không nhất quán.', + retrievalSettings: 'Cài đặt truy xuất', + rerankSettings: 'Cài đặt xếp hạng lại', + weightedScore: { + title: 'Điểm số có trọng số', + description: 'Bằng cách điều chỉnh trọng số được gán, chiến lược xếp hạng lại này xác định liệu ưu tiên khớp ngữ nghĩa hay từ khóa.', + semanticFirst: 'Ngữ nghĩa trước', + keywordFirst: 'Từ khóa trước', + customized: 'Tùy chỉnh', + semantic: 'Ngữ nghĩa', + keyword: 'Từ khóa', + }, + nTo1RetrievalLegacy: 'Truy xuất N-đến-1 sẽ chính thức bị loại bỏ từ tháng 9. Khuyến nghị sử dụng truy xuất đa đường dẫn mới nhất để có kết quả tốt hơn.', + nTo1RetrievalLegacyLink: 'Tìm hiểu thêm', + nTo1RetrievalLegacyLinkText: 'Truy xuất N-đến-1 sẽ chính thức bị loại bỏ vào tháng 9.', } export default translation diff --git a/web/i18n/zh-Hant/dataset.ts b/web/i18n/zh-Hant/dataset.ts index 8de7bc487f..1e011bc987 100644 --- a/web/i18n/zh-Hant/dataset.ts +++ b/web/i18n/zh-Hant/dataset.ts @@ -45,6 +45,32 @@ const translation = { }, docsFailedNotice: '文件無法被索引', retry: '重試', + indexingTechnique: { + high_quality: '高質量', + economy: '經濟', + }, + indexingMethod: { + semantic_search: '向量', + full_text_search: '全文', + hybrid_search: '混合', + invertedIndex: '倒排索引', + }, + mixtureHighQualityAndEconomicTip: '混合高質量和經濟知識庫需要重新排序模型。', + inconsistentEmbeddingModelTip: '如果選定知識庫的嵌入模型不一致,則需要重新排序模型。', + retrievalSettings: '檢索設置', + rerankSettings: '重新排序設置', + weightedScore: { + title: '加權分數', + description: '通過調整分配的權重,此重新排序策略決定是優先考慮語義匹配還是關鍵詞匹配。', + semanticFirst: '語義優先', + keywordFirst: '關鍵詞優先', + customized: '自定義', + semantic: '語義', + keyword: '關鍵詞', + }, + nTo1RetrievalLegacy: 'N對1檢索將從9月起正式棄用。建議使用最新的多路徑檢索以獲得更好的結果。', + nTo1RetrievalLegacyLink: '了解更多', + nTo1RetrievalLegacyLinkText: 'N對1檢索將於9月正式棄用。', } export default translation From a71fc185309f8b1fdc1dc079b3f6185848189091 Mon Sep 17 00:00:00 2001 From: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com> Date: Fri, 23 Aug 2024 13:11:55 +0800 Subject: [PATCH 06/24] feat: set workflow zoom range for shortcut (#7563) --- .../workflow/hooks/use-shortcuts.ts | 21 +++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/web/app/components/workflow/hooks/use-shortcuts.ts b/web/app/components/workflow/hooks/use-shortcuts.ts index 9484f9c16e..666c3a45ba 100644 --- a/web/app/components/workflow/hooks/use-shortcuts.ts +++ b/web/app/components/workflow/hooks/use-shortcuts.ts @@ -37,12 +37,25 @@ export const useShortcuts = (): void => { const { handleLayout } = useWorkflowOrganize() const { - zoomIn, - zoomOut, zoomTo, + getZoom, fitView, } = useReactFlow() + // Zoom out to a minimum of 0.5 for shortcut + const constrainedZoomOut = () => { + const currentZoom = getZoom() + const newZoom = Math.max(currentZoom - 0.1, 0.5) + zoomTo(newZoom) + } + + // Zoom in to a maximum of 1 for shortcut + const constrainedZoomIn = () => { + const currentZoom = getZoom() + const newZoom = Math.min(currentZoom + 0.1, 1) + zoomTo(newZoom) + } + const shouldHandleShortcut = useCallback((e: KeyboardEvent) => { const { showFeaturesPanel } = workflowStore.getState() return !showFeaturesPanel && !isEventTargetInputArea(e.target as HTMLElement) @@ -165,7 +178,7 @@ export const useShortcuts = (): void => { useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.dash`, (e) => { if (shouldHandleShortcut(e)) { e.preventDefault() - zoomOut() + constrainedZoomOut() handleSyncWorkflowDraft() } }, { @@ -176,7 +189,7 @@ export const useShortcuts = (): void => { useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.equalsign`, (e) => { if (shouldHandleShortcut(e)) { e.preventDefault() - zoomIn() + constrainedZoomIn() handleSyncWorkflowDraft() } }, { From 9618f86980a69ffaa5f48373267dc5aa22c3771a Mon Sep 17 00:00:00 2001 From: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com> Date: Fri, 23 Aug 2024 13:14:17 +0800 Subject: [PATCH 07/24] fix: workflow context menu popup issue (#7530) --- web/app/components/workflow/node-contextmenu.tsx | 12 +++++++++--- web/app/components/workflow/panel-contextmenu.tsx | 8 +++++++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/web/app/components/workflow/node-contextmenu.tsx b/web/app/components/workflow/node-contextmenu.tsx index adfed37b26..311bf1fddf 100644 --- a/web/app/components/workflow/node-contextmenu.tsx +++ b/web/app/components/workflow/node-contextmenu.tsx @@ -1,5 +1,6 @@ import { memo, + useEffect, useRef, } from 'react' import { useClickAway } from 'ahooks' @@ -9,13 +10,18 @@ import type { Node } from './types' import { useStore } from './store' import { usePanelInteractions } from './hooks' -const PanelContextmenu = () => { +const NodeContextmenu = () => { const ref = useRef(null) const nodes = useNodes() - const { handleNodeContextmenuCancel } = usePanelInteractions() + const { handleNodeContextmenuCancel, handlePaneContextmenuCancel } = usePanelInteractions() const nodeMenu = useStore(s => s.nodeMenu) const currentNode = nodes.find(node => node.id === nodeMenu?.nodeId) as Node + useEffect(() => { + if (nodeMenu) + handlePaneContextmenuCancel() + }, [nodeMenu, handlePaneContextmenuCancel]) + useClickAway(() => { handleNodeContextmenuCancel() }, ref) @@ -42,4 +48,4 @@ const PanelContextmenu = () => { ) } -export default memo(PanelContextmenu) +export default memo(NodeContextmenu) diff --git a/web/app/components/workflow/panel-contextmenu.tsx b/web/app/components/workflow/panel-contextmenu.tsx index 502967ce2c..f01e3037a2 100644 --- a/web/app/components/workflow/panel-contextmenu.tsx +++ b/web/app/components/workflow/panel-contextmenu.tsx @@ -1,5 +1,6 @@ import { memo, + useEffect, useRef, } from 'react' import { useTranslation } from 'react-i18next' @@ -23,11 +24,16 @@ const PanelContextmenu = () => { const clipboardElements = useStore(s => s.clipboardElements) const setShowImportDSLModal = useStore(s => s.setShowImportDSLModal) const { handleNodesPaste } = useNodesInteractions() - const { handlePaneContextmenuCancel } = usePanelInteractions() + const { handlePaneContextmenuCancel, handleNodeContextmenuCancel } = usePanelInteractions() const { handleStartWorkflowRun } = useWorkflowStartRun() const { handleAddNote } = useOperator() const { exportCheck } = useDSL() + useEffect(() => { + if (panelMenu) + handleNodeContextmenuCancel() + }, [panelMenu, handleNodeContextmenuCancel]) + useClickAway(() => { handlePaneContextmenuCancel() }, ref) From 0a7ab9a47dc4858ed6d0cf8279add6ad299b6031 Mon Sep 17 00:00:00 2001 From: edo1z <89882017+edo1z@users.noreply.github.com> Date: Fri, 23 Aug 2024 14:16:15 +0900 Subject: [PATCH 08/24] fix: incorrect duplication when no target node is selected (#7539) --- .../workflow/hooks/use-nodes-interactions.ts | 30 ++++++++++++------- .../panel-operator/panel-operator-popup.tsx | 2 +- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index 87d1b4de8c..3645e18449 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -1027,7 +1027,7 @@ export const useNodesInteractions = () => { handleNodeSelect(node.id) }, [workflowStore, handleNodeSelect]) - const handleNodesCopy = useCallback(() => { + const handleNodesCopy = useCallback((nodeId?: string) => { if (getNodesReadOnly()) return @@ -1038,17 +1038,27 @@ export const useNodesInteractions = () => { } = store.getState() const nodes = getNodes() - const bundledNodes = nodes.filter(node => node.data._isBundled && node.data.type !== BlockEnum.Start && !node.data.isInIteration) - if (bundledNodes.length) { - setClipboardElements(bundledNodes) - return + if (nodeId) { + // If nodeId is provided, copy that specific node + const nodeToCopy = nodes.find(node => node.id === nodeId && node.data.type !== BlockEnum.Start) + if (nodeToCopy) + setClipboardElements([nodeToCopy]) } + else { + // If no nodeId is provided, fall back to the current behavior + const bundledNodes = nodes.filter(node => node.data._isBundled && node.data.type !== BlockEnum.Start && !node.data.isInIteration) - const selectedNode = nodes.find(node => node.data.selected && node.data.type !== BlockEnum.Start) + if (bundledNodes.length) { + setClipboardElements(bundledNodes) + return + } - if (selectedNode) - setClipboardElements([selectedNode]) + const selectedNode = nodes.find(node => node.data.selected && node.data.type !== BlockEnum.Start) + + if (selectedNode) + setClipboardElements([selectedNode]) + } }, [getNodesReadOnly, store, workflowStore]) const handleNodesPaste = useCallback(() => { @@ -1128,11 +1138,11 @@ export const useNodesInteractions = () => { } }, [getNodesReadOnly, workflowStore, store, reactflow, saveStateToHistory, handleSyncWorkflowDraft, handleNodeIterationChildrenCopy]) - const handleNodesDuplicate = useCallback(() => { + const handleNodesDuplicate = useCallback((nodeId?: string) => { if (getNodesReadOnly()) return - handleNodesCopy() + handleNodesCopy(nodeId) handleNodesPaste() }, [getNodesReadOnly, handleNodesCopy, handleNodesPaste]) diff --git a/web/app/components/workflow/nodes/_base/components/panel-operator/panel-operator-popup.tsx b/web/app/components/workflow/nodes/_base/components/panel-operator/panel-operator-popup.tsx index aade4d8be8..bd642fcd66 100644 --- a/web/app/components/workflow/nodes/_base/components/panel-operator/panel-operator-popup.tsx +++ b/web/app/components/workflow/nodes/_base/components/panel-operator/panel-operator-popup.tsx @@ -138,7 +138,7 @@ const PanelOperatorPopup = ({ className='flex items-center justify-between px-3 h-8 text-sm text-gray-700 rounded-lg cursor-pointer hover:bg-gray-50' onClick={() => { onClosePopup() - handleNodesDuplicate() + handleNodesDuplicate(id) }} > {t('workflow.common.duplicate')} From 399d7cd596e2d276e7c6df366c92293488c6e3e9 Mon Sep 17 00:00:00 2001 From: Joel Date: Fri, 23 Aug 2024 14:30:26 +0800 Subject: [PATCH 09/24] chore: improve the check time of variable name (#7569) --- .../config-var/config-modal/index.tsx | 37 ++++++++++--------- web/i18n/de-DE/app-debug.ts | 10 ++--- web/i18n/en-US/app-debug.ts | 11 +++--- web/i18n/es-ES/app-debug.ts | 10 ++--- web/i18n/fa-IR/app-debug.ts | 10 ++--- web/i18n/fr-FR/app-debug.ts | 10 ++--- web/i18n/hi-IN/app-debug.ts | 10 ++--- web/i18n/it-IT/app-debug.ts | 10 ++--- web/i18n/ja-JP/app-debug.ts | 10 ++--- web/i18n/ko-KR/app-debug.ts | 10 ++--- web/i18n/pl-PL/app-debug.ts | 10 ++--- web/i18n/pt-BR/app-debug.ts | 10 ++--- web/i18n/ro-RO/app-debug.ts | 10 ++--- web/i18n/tr-TR/app-debug.ts | 10 ++--- web/i18n/uk-UA/app-debug.ts | 10 ++--- web/i18n/vi-VN/app-debug.ts | 10 ++--- web/i18n/zh-Hans/app-debug.ts | 11 +++--- web/i18n/zh-Hant/app-debug.ts | 10 ++--- 18 files changed, 105 insertions(+), 104 deletions(-) diff --git a/web/app/components/app/configuration/config-var/config-modal/index.tsx b/web/app/components/app/configuration/config-var/config-modal/index.tsx index 20fcf49de1..3296c77fb2 100644 --- a/web/app/components/app/configuration/config-var/config-modal/index.tsx +++ b/web/app/components/app/configuration/config-var/config-modal/index.tsx @@ -42,18 +42,19 @@ const ConfigModal: FC = ({ const { type, label, variable, options, max_length } = tempPayload const isStringInput = type === InputVarType.textInput || type === InputVarType.paragraph + const checkVariableName = useCallback((value: string) => { + const { isValid, errorMessageKey } = checkKeys([value], false) + if (!isValid) { + Toast.notify({ + type: 'error', + message: t(`appDebug.varKeyError.${errorMessageKey}`, { key: t('appDebug.variableConig.varName') }), + }) + return false + } + return true + }, [t]) const handlePayloadChange = useCallback((key: string) => { return (value: any) => { - if (key === 'variable') { - const { isValid, errorKey, errorMessageKey } = checkKeys([value], true) - if (!isValid) { - Toast.notify({ - type: 'error', - message: t(`appDebug.varKeyError.${errorMessageKey}`, { key: errorKey }), - }) - return - } - } setTempPayload((prev) => { const newPayload = { ...prev, @@ -63,19 +64,20 @@ const ConfigModal: FC = ({ return newPayload }) } - }, [t]) + }, []) const handleVarKeyBlur = useCallback((e: any) => { - if (tempPayload.label) + const varName = e.target.value + if (!checkVariableName(varName) || tempPayload.label) return setTempPayload((prev) => { return { ...prev, - label: e.target.value, + label: varName, } }) - }, [tempPayload]) + }, [checkVariableName, tempPayload.label]) const handleConfirm = () => { const moreInfo = tempPayload.variable === payload?.variable @@ -84,10 +86,11 @@ const ConfigModal: FC = ({ type: ChangeType.changeVarName, payload: { beforeKey: payload?.variable || '', afterKey: tempPayload.variable }, } - if (!tempPayload.variable) { - Toast.notify({ type: 'error', message: t('appDebug.variableConig.errorMsg.varNameRequired') }) + + const isVariableNameValid = checkVariableName(tempPayload.variable) + if (!isVariableNameValid) return - } + // TODO: check if key already exists. should the consider the edit case // if (varKeys.map(key => key?.trim()).includes(tempPayload.variable.trim())) { // Toast.notify({ diff --git a/web/i18n/de-DE/app-debug.ts b/web/i18n/de-DE/app-debug.ts index e78b6c0d7a..acb3f53904 100644 --- a/web/i18n/de-DE/app-debug.ts +++ b/web/i18n/de-DE/app-debug.ts @@ -237,11 +237,11 @@ const translation = { typeSelect: 'Auswählen', }, varKeyError: { - canNoBeEmpty: 'Variablenschlüssel darf nicht leer sein', - tooLong: 'Variablenschlüssel: {{key}} zu lang. Darf nicht länger als 30 Zeichen sein', - notValid: 'Variablenschlüssel: {{key}} ist ungültig. Darf nur Buchstaben, Zahlen und Unterstriche enthalten', - notStartWithNumber: 'Variablenschlüssel: {{key}} darf nicht mit einer Zahl beginnen', - keyAlreadyExists: 'Variablenschlüssel: :{{key}} existiert bereits', + canNoBeEmpty: '{{key}} ist erforderlich', + tooLong: '{{key}} zu lang. Darf nicht länger als 30 Zeichen sein', + notValid: '{{key}} ist ungültig. Darf nur Buchstaben, Zahlen und Unterstriche enthalten', + notStartWithNumber: '{{key}} darf nicht mit einer Zahl beginnen', + keyAlreadyExists: '{{key}} existiert bereits', }, otherError: { promptNoBeEmpty: 'Prompt darf nicht leer sein', diff --git a/web/i18n/en-US/app-debug.ts b/web/i18n/en-US/app-debug.ts index a4e8a4f7fa..86c5f720c3 100644 --- a/web/i18n/en-US/app-debug.ts +++ b/web/i18n/en-US/app-debug.ts @@ -290,11 +290,11 @@ const translation = { typeSelect: 'Select', }, varKeyError: { - canNoBeEmpty: 'Variable key can not be empty', - tooLong: 'Variable key: {{key}} too length. Can not be longer then 30 characters', - notValid: 'Variable key: {{key}} is invalid. Can only contain letters, numbers, and underscores', - notStartWithNumber: 'Variable key: {{key}} can not start with a number', - keyAlreadyExists: 'Variable key: :{{key}} already exists', + canNoBeEmpty: '{{key}} is required', + tooLong: '{{key}} is too length. Can not be longer then 30 characters', + notValid: '{{key}} is invalid. Can only contain letters, numbers, and underscores', + notStartWithNumber: '{{key}} can not start with a number', + keyAlreadyExists: '{{key}} already exists', }, otherError: { promptNoBeEmpty: 'Prompt can not be empty', @@ -323,7 +323,6 @@ const translation = { 'content': 'Content', 'required': 'Required', 'errorMsg': { - varNameRequired: 'Variable name is required', labelNameRequired: 'Label name is required', varNameCanBeRepeat: 'Variable name can not be repeated', atLeastOneOption: 'At least one option is required', diff --git a/web/i18n/es-ES/app-debug.ts b/web/i18n/es-ES/app-debug.ts index 9e309a7d62..68088c26a6 100644 --- a/web/i18n/es-ES/app-debug.ts +++ b/web/i18n/es-ES/app-debug.ts @@ -248,11 +248,11 @@ const translation = { typeSelect: 'Seleccionar', }, varKeyError: { - canNoBeEmpty: 'La clave de la variable no puede estar vacía', - tooLong: 'Clave de la variable: {{key}} demasiado larga. No puede tener más de 30 caracteres', - notValid: 'Clave de la variable: {{key}} no es válida. Solo puede contener letras, números y guiones bajos', - notStartWithNumber: 'Clave de la variable: {{key}} no puede comenzar con un número', - keyAlreadyExists: 'Clave de la variable: {{key}} ya existe', + canNoBeEmpty: 'Se requiere {{key}}', + tooLong: '{{key}} demasiado larga. No puede tener más de 30 caracteres', + notValid: '{{key}} no es válida. Solo puede contener letras, números y guiones bajos', + notStartWithNumber: '{{key}} no puede comenzar con un número', + keyAlreadyExists: '{{key}} ya existe', }, otherError: { promptNoBeEmpty: 'La indicación no puede estar vacía', diff --git a/web/i18n/fa-IR/app-debug.ts b/web/i18n/fa-IR/app-debug.ts index 863f47bb18..1ce222581d 100644 --- a/web/i18n/fa-IR/app-debug.ts +++ b/web/i18n/fa-IR/app-debug.ts @@ -283,11 +283,11 @@ const translation = { typeSelect: 'انتخاب', }, varKeyError: { - canNoBeEmpty: 'کلید متغیر نمی‌تواند خالی باشد', - tooLong: 'کلید متغیر: {{key}} طولانی است. نمی‌تواند بیش از 30 کاراکتر باشد', - notValid: 'کلید متغیر: {{key}} نامعتبر است. فقط می‌تواند شامل حروف، اعداد و زیرخط باشد', - notStartWithNumber: 'کلید متغیر: {{key}} نمی‌تواند با عدد شروع شود', - keyAlreadyExists: 'کلید متغیر: :{{key}} از قبل وجود دارد', + canNoBeEmpty: '{{key}} مطلوب', + tooLong: '{{key}} طولانی است. نمی‌تواند بیش از 30 کاراکتر باشد', + notValid: '{{key}} نامعتبر است. فقط می‌تواند شامل حروف، اعداد و زیرخط باشد', + notStartWithNumber: '{{key}} نمی‌تواند با عدد شروع شود', + keyAlreadyExists: '{{key}} از قبل وجود دارد', }, otherError: { promptNoBeEmpty: 'پرس و جو نمی‌تواند خالی باشد', diff --git a/web/i18n/fr-FR/app-debug.ts b/web/i18n/fr-FR/app-debug.ts index 91d2dcb142..b71d251956 100644 --- a/web/i18n/fr-FR/app-debug.ts +++ b/web/i18n/fr-FR/app-debug.ts @@ -237,11 +237,11 @@ const translation = { typeSelect: 'Sélectionner', }, varKeyError: { - canNoBeEmpty: 'La clé variable ne peut pas être vide', - tooLong: 'Variable key: {{key}} too length. Can not be longer then 30 characters', - notValid: 'Variable key: {{key}} is invalid. Can only contain letters, numbers, and underscores', - notStartWithNumber: 'Variable key: {{key}} can not start with a number', - keyAlreadyExists: 'Variable key: :{{key}} already exists', + canNoBeEmpty: '{{key}} est obligatoire', + tooLong: '{{key}} too length. Can not be longer then 30 characters', + notValid: '{{key}} is invalid. Can only contain letters, numbers, and underscores', + notStartWithNumber: '{{key}} can not start with a number', + keyAlreadyExists: '{{key}} already exists', }, otherError: { promptNoBeEmpty: 'Le prompt ne peut pas être vide', diff --git a/web/i18n/hi-IN/app-debug.ts b/web/i18n/hi-IN/app-debug.ts index 872a1b9fe8..29944c8d84 100644 --- a/web/i18n/hi-IN/app-debug.ts +++ b/web/i18n/hi-IN/app-debug.ts @@ -276,14 +276,14 @@ const translation = { typeSelect: 'चुनें', }, varKeyError: { - canNoBeEmpty: 'वेरिएबल कुंजी खाली नहीं हो सकती', + canNoBeEmpty: '{{key}} आवश्यक है', tooLong: - 'वेरिएबल कुंजी: {{key}} बहुत लंबी है। 30 वर्णों से अधिक नहीं हो सकती', + '{{key}} बहुत लंबी है। 30 वर्णों से अधिक नहीं हो सकती', notValid: - 'वेरिएबल कुंजी: {{key}} अवैध है। केवल अक्षर, संख्याएं, और अंडरस्कोर शामिल हो सकते हैं', + '{{key}} अवैध है। केवल अक्षर, संख्याएं, और अंडरस्कोर शामिल हो सकते हैं', notStartWithNumber: - 'वेरिएबल कुंजी: {{key}} एक संख्या से प्रारंभ नहीं हो सकती', - keyAlreadyExists: 'वेरिएबल कुंजी: {{key}} पहले से मौजूद है', + '{{key}} एक संख्या से प्रारंभ नहीं हो सकती', + keyAlreadyExists: '{{key}} पहले से मौजूद है', }, otherError: { promptNoBeEmpty: 'प्रॉम्प्ट खाली नहीं हो सकता', diff --git a/web/i18n/it-IT/app-debug.ts b/web/i18n/it-IT/app-debug.ts index 8efe575945..a4cf7bba2d 100644 --- a/web/i18n/it-IT/app-debug.ts +++ b/web/i18n/it-IT/app-debug.ts @@ -278,14 +278,14 @@ const translation = { typeSelect: 'Seleziona', }, varKeyError: { - canNoBeEmpty: 'La chiave della variabile non può essere vuota', + canNoBeEmpty: '{{key}} è obbligatorio', tooLong: - 'La chiave della variabile: {{key}} è troppo lunga. Non può essere più lunga di 30 caratteri', + '{{key}} è troppo lunga. Non può essere più lunga di 30 caratteri', notValid: - 'La chiave della variabile: {{key}} non è valida. Può contenere solo lettere, numeri e underscore', + '{{key}} non è valida. Può contenere solo lettere, numeri e underscore', notStartWithNumber: - 'La chiave della variabile: {{key}} non può iniziare con un numero', - keyAlreadyExists: 'La chiave della variabile: {{key}} esiste già', + '{{key}} non può iniziare con un numero', + keyAlreadyExists: '{{key}} esiste già', }, otherError: { promptNoBeEmpty: 'Il prompt non può essere vuoto', diff --git a/web/i18n/ja-JP/app-debug.ts b/web/i18n/ja-JP/app-debug.ts index 39ea3386ed..6049be2406 100644 --- a/web/i18n/ja-JP/app-debug.ts +++ b/web/i18n/ja-JP/app-debug.ts @@ -284,11 +284,11 @@ const translation = { typeSelect: '選択', }, varKeyError: { - canNoBeEmpty: '変数キーを空にすることはできません', - tooLong: '変数キー: {{key}} が長すぎます。30文字を超えることはできません', - notValid: '変数キー: {{key}} が無効です。文字、数字、アンダースコアのみを含めることができます', - notStartWithNumber: '変数キー: {{key}} は数字で始めることはできません', - keyAlreadyExists: '変数キー: {{key}} はすでに存在します', + canNoBeEmpty: '{{key}} は必須です', + tooLong: '{{key}} が長すぎます。30文字を超えることはできません', + notValid: '{{key}} が無効です。文字、数字、アンダースコアのみを含めることができます', + notStartWithNumber: '{{key}} は数字で始めることはできません', + keyAlreadyExists: '{{key}} はすでに存在します', }, otherError: { promptNoBeEmpty: 'プロンプトを空にすることはできません', diff --git a/web/i18n/ko-KR/app-debug.ts b/web/i18n/ko-KR/app-debug.ts index 77eac2503d..0a2488b64c 100644 --- a/web/i18n/ko-KR/app-debug.ts +++ b/web/i18n/ko-KR/app-debug.ts @@ -248,11 +248,11 @@ const translation = { typeSelect: '선택', }, varKeyError: { - canNoBeEmpty: '변수 키를 비울 수 없습니다', - tooLong: '변수 키: {{key}}가 너무 깁니다. 30자를 넘을 수 없습니다', - notValid: '변수 키: {{key}}가 유효하지 않습니다. 문자, 숫자, 밑줄만 포함할 수 있습니다', - notStartWithNumber: '변수 키: {{key}}는 숫자로 시작할 수 없습니다', - keyAlreadyExists: '변수 키: {{key}}는 이미 존재합니다', + canNoBeEmpty: '{{key}}가 필요합니다', + tooLong: '{{key}}가 너무 깁니다. 30자를 넘을 수 없습니다', + notValid: '{{key}}가 유효하지 않습니다. 문자, 숫자, 밑줄만 포함할 수 있습니다', + notStartWithNumber: '{{key}}는 숫자로 시작할 수 없습니다', + keyAlreadyExists: '{{key}}는 이미 존재합니다', }, otherError: { promptNoBeEmpty: '프롬프트를 비울 수 없습니다', diff --git a/web/i18n/pl-PL/app-debug.ts b/web/i18n/pl-PL/app-debug.ts index 960209c045..afb412f264 100644 --- a/web/i18n/pl-PL/app-debug.ts +++ b/web/i18n/pl-PL/app-debug.ts @@ -275,14 +275,14 @@ const translation = { typeSelect: 'Wybierz', }, varKeyError: { - canNoBeEmpty: 'Klucz zmiennej nie może być pusty', + canNoBeEmpty: '{{klucz}} jest wymagany', tooLong: - 'Klucz zmiennej: {{key}} za długi. Nie może być dłuższy niż 30 znaków', + '{{key}} za długi. Nie może być dłuższy niż 30 znaków', notValid: - 'Klucz zmiennej: {{key}} jest nieprawidłowy. Może zawierać tylko litery, cyfry i podkreślenia', + '{{key}} jest nieprawidłowy. Może zawierać tylko litery, cyfry i podkreślenia', notStartWithNumber: - 'Klucz zmiennej: {{key}} nie może zaczynać się od cyfry', - keyAlreadyExists: 'Klucz zmiennej: :{{key}} już istnieje', + '{{key}} nie może zaczynać się od cyfry', + keyAlreadyExists: '{{key}} już istnieje', }, otherError: { promptNoBeEmpty: 'Monit nie może być pusty', diff --git a/web/i18n/pt-BR/app-debug.ts b/web/i18n/pt-BR/app-debug.ts index 91730d44b3..9605bd5d95 100644 --- a/web/i18n/pt-BR/app-debug.ts +++ b/web/i18n/pt-BR/app-debug.ts @@ -254,11 +254,11 @@ const translation = { typeSelect: 'Selecionar', }, varKeyError: { - canNoBeEmpty: 'A chave da variável não pode estar vazia', - tooLong: 'A chave da variável: {{key}} é muito longa. Não pode ter mais de 30 caracteres', - notValid: 'A chave da variável: {{key}} é inválida. Pode conter apenas letras, números e sublinhados', - notStartWithNumber: 'A chave da variável: {{key}} não pode começar com um número', - keyAlreadyExists: 'A chave da variável: :{{key}} já existe', + canNoBeEmpty: '{{key}} é obrigatório', + tooLong: '{{key}} é muito longa. Não pode ter mais de 30 caracteres', + notValid: '{{key}} é inválida. Pode conter apenas letras, números e sublinhados', + notStartWithNumber: '{{key}} não pode começar com um número', + keyAlreadyExists: '{{key}} já existe', }, otherError: { promptNoBeEmpty: 'A solicitação não pode estar vazia', diff --git a/web/i18n/ro-RO/app-debug.ts b/web/i18n/ro-RO/app-debug.ts index b4e9442de8..7363f2954f 100644 --- a/web/i18n/ro-RO/app-debug.ts +++ b/web/i18n/ro-RO/app-debug.ts @@ -254,11 +254,11 @@ const translation = { typeSelect: 'Selectează', }, varKeyError: { - canNoBeEmpty: 'Cheia variabilei nu poate fi goală', - tooLong: 'Cheia variabilei: {{key}} este prea lungă. Nu poate fi mai lungă de 30 de caractere', - notValid: 'Cheia variabilei: {{key}} este nevalidă. Poate conține doar litere, cifre și sublinieri', - notStartWithNumber: 'Cheia variabilei: {{key}} nu poate începe cu un număr', - keyAlreadyExists: 'Cheia variabilei: :{{key}} deja există', + canNoBeEmpty: '{{key}} este necesară', + tooLong: '{{key}} este prea lungă. Nu poate fi mai lungă de 30 de caractere', + notValid: '{{key}} este nevalidă. Poate conține doar litere, cifre și sublinieri', + notStartWithNumber: '{{key}} nu poate începe cu un număr', + keyAlreadyExists: ':{{key}} deja există', }, otherError: { promptNoBeEmpty: 'Promptul nu poate fi gol', diff --git a/web/i18n/tr-TR/app-debug.ts b/web/i18n/tr-TR/app-debug.ts index 889f48c78e..fbf51535fe 100644 --- a/web/i18n/tr-TR/app-debug.ts +++ b/web/i18n/tr-TR/app-debug.ts @@ -290,11 +290,11 @@ const translation = { typeSelect: 'Seçim', }, varKeyError: { - canNoBeEmpty: 'Değişken anahtarı boş olamaz', - tooLong: 'Değişken anahtarı: {{key}} çok uzun. 30 karakterden uzun olamaz', - notValid: 'Değişken anahtarı: {{key}} geçersizdir. Sadece harfler, rakamlar ve altçizgiler içerebilir', - notStartWithNumber: 'Değişken anahtarı: {{key}} bir rakamla başlamamalıdır', - keyAlreadyExists: 'Değişken anahtarı: {{key}} zaten mevcut', + canNoBeEmpty: '{{key}} gereklidir', + tooLong: '{{key}} çok uzun. 30 karakterden uzun olamaz', + notValid: '{{key}} geçersizdir. Sadece harfler, rakamlar ve altçizgiler içerebilir', + notStartWithNumber: '{{key}} bir rakamla başlamamalıdır', + keyAlreadyExists: '{{key}} zaten mevcut', }, otherError: { promptNoBeEmpty: 'Prompt boş olamaz', diff --git a/web/i18n/uk-UA/app-debug.ts b/web/i18n/uk-UA/app-debug.ts index c64444c871..7c0ba45b3c 100644 --- a/web/i18n/uk-UA/app-debug.ts +++ b/web/i18n/uk-UA/app-debug.ts @@ -248,11 +248,11 @@ const translation = { typeSelect: 'Вибрати', // Select }, varKeyError: { - canNoBeEmpty: 'Ключ змінної не може бути порожнім', // Variable key can not be empty - tooLong: 'Ключ змінної: {{key}} занадто довгий. Не може бути більше 30 символів', // Variable key: {{key}} too length. Can not be longer then 30 characters - notValid: 'Ключ змінної: {{key}} недійсний. Може містити лише літери, цифри та підкреслення', // Variable key: {{key}} is invalid. Can only contain letters, numbers, and underscores - notStartWithNumber: 'Ключ змінної: {{key}} не може починатися з цифри', // Variable key: {{key}} can not start with a number - keyAlreadyExists: 'Ключ змінної: :{{key}} вже існує', // Variable key: :{{key}} already exists + canNoBeEmpty: 'Потрібен {{key}}', // Variable key can not be empty + tooLong: '{{key}} занадто довгий. Не може бути більше 30 символів', // Variable key: {{key}} too length. Can not be longer then 30 characters + notValid: '{{key}} недійсний. Може містити лише літери, цифри та підкреслення', // Variable key: {{key}} is invalid. Can only contain letters, numbers, and underscores + notStartWithNumber: '{{key}} не може починатися з цифри', // Variable key: {{key}} can not start with a number + keyAlreadyExists: ':{{key}} вже існує', // Variable key: :{{key}} already exists }, otherError: { promptNoBeEmpty: 'Команда не може бути порожньою', // Prompt can not be empty diff --git a/web/i18n/vi-VN/app-debug.ts b/web/i18n/vi-VN/app-debug.ts index 4797f768e3..906b39d10a 100644 --- a/web/i18n/vi-VN/app-debug.ts +++ b/web/i18n/vi-VN/app-debug.ts @@ -248,11 +248,11 @@ const translation = { typeSelect: 'Lựa chọn', }, varKeyError: { - canNoBeEmpty: 'Khóa biến không thể trống', - tooLong: 'Khóa biến: {{key}} quá dài. Không thể dài hơn 30 ký tự', - notValid: 'Khóa biến: {{key}} không hợp lệ. Chỉ có thể chứa chữ cái, số, và dấu gạch dưới', - notStartWithNumber: 'Khóa biến: {{key}} không thể bắt đầu bằng số', - keyAlreadyExists: 'Khóa biến: {{key}} đã tồn tại', + canNoBeEmpty: '{{key}} là bắt buộc', + tooLong: '{{key}} quá dài. Không thể dài hơn 30 ký tự', + notValid: '{{key}} không hợp lệ. Chỉ có thể chứa chữ cái, số, và dấu gạch dưới', + notStartWithNumber: '{{key}} không thể bắt đầu bằng số', + keyAlreadyExists: '{{key}} đã tồn tại', }, otherError: { promptNoBeEmpty: 'Lời nhắc không thể trống', diff --git a/web/i18n/zh-Hans/app-debug.ts b/web/i18n/zh-Hans/app-debug.ts index b95bb8ce51..febf80d786 100644 --- a/web/i18n/zh-Hans/app-debug.ts +++ b/web/i18n/zh-Hans/app-debug.ts @@ -287,11 +287,11 @@ const translation = { typeSelect: '下拉选项', }, varKeyError: { - canNoBeEmpty: '变量不能为空', - tooLong: '变量: {{key}} 长度太长。不能超过 30 个字符', - notValid: '变量: {{key}} 非法。只能包含英文字符,数字和下划线', - notStartWithNumber: '变量: {{key}} 不能以数字开头', - keyAlreadyExists: '变量:{{key}} 已存在', + canNoBeEmpty: '{{key}}必填', + tooLong: '{{key}} 长度太长。不能超过 30 个字符', + notValid: '{{key}} 非法。只能包含英文字符,数字和下划线', + notStartWithNumber: '{{key}} 不能以数字开头', + keyAlreadyExists: '{{key}} 已存在', }, otherError: { promptNoBeEmpty: '提示词不能为空', @@ -320,7 +320,6 @@ const translation = { 'required': '必填', 'content': '内容', 'errorMsg': { - varNameRequired: '变量名称必填', labelNameRequired: '显示名称必填', varNameCanBeRepeat: '变量名称不能重复', atLeastOneOption: '至少需要一个选项', diff --git a/web/i18n/zh-Hant/app-debug.ts b/web/i18n/zh-Hant/app-debug.ts index ec5d536bb9..ca4dfbb0cf 100644 --- a/web/i18n/zh-Hant/app-debug.ts +++ b/web/i18n/zh-Hant/app-debug.ts @@ -233,11 +233,11 @@ const translation = { typeSelect: '下拉選項', }, varKeyError: { - canNoBeEmpty: '變數不能為空', - tooLong: '變數: {{key}} 長度太長。不能超過 30 個字元', - notValid: '變數: {{key}} 非法。只能包含英文字元,數字和下劃線', - notStartWithNumber: '變數: {{key}} 不能以數字開頭', - keyAlreadyExists: '變數:{{key}} 已存在', + canNoBeEmpty: '{{key}} 是必要的', + tooLong: '{{key}} 長度太長。不能超過 30 個字元', + notValid: '{{key}} 非法。只能包含英文字元,數字和下劃線', + notStartWithNumber: '{{key}} 不能以數字開頭', + keyAlreadyExists: '{{key}} 已存在', }, otherError: { promptNoBeEmpty: '提示詞不能為空', From fb75bd979048ed410616578dc9e825734174240a Mon Sep 17 00:00:00 2001 From: Joel Date: Fri, 23 Aug 2024 15:08:34 +0800 Subject: [PATCH 10/24] chore: improve the check time of variable name in conversation and env var (#7572) --- .../components/variable-modal.tsx | 26 +++++++++++-------- .../panel/env-panel/variable-modal.tsx | 26 +++++++++++-------- 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx b/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx index 289e29d592..e6c1ebb5cc 100644 --- a/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx +++ b/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx @@ -15,6 +15,7 @@ import type { ConversationVariable } from '@/app/components/workflow/types' import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' import { ChatVarType } from '@/app/components/workflow/panel/chat-variable-panel/type' import cn from '@/utils/classnames' +import { checkKeys } from '@/utils/var' export type ModalPropsType = { chatVar?: ConversationVariable @@ -128,14 +129,16 @@ const ChatVariableModal = ({ } } - const handleNameChange = (v: string) => { - if (!v) - return setName('') - if (!/^[a-zA-Z0-9_]+$/.test(v)) - return notify({ type: 'error', message: 'name is can only contain letters, numbers and underscores' }) - if (/^[0-9]/.test(v)) - return notify({ type: 'error', message: 'name can not start with a number' }) - setName(v) + const checkVariableName = (value: string) => { + const { isValid, errorMessageKey } = checkKeys([value], false) + if (!isValid) { + notify({ + type: 'error', + message: t(`appDebug.varKeyError.${errorMessageKey}`, { key: t('workflow.env.modal.name') }), + }) + return false + } + return true } const handleTypeChange = (v: ChatVarType) => { @@ -211,8 +214,8 @@ const ChatVariableModal = ({ } const handleSave = () => { - if (!name) - return notify({ type: 'error', message: 'name can not be empty' }) + if (!checkVariableName(name)) + return if (!chatVar && varList.some(chatVar => chatVar.name === name)) return notify({ type: 'error', message: 'name is existed' }) // if (type !== ChatVarType.Object && !value) @@ -272,7 +275,8 @@ const ChatVariableModal = ({ className='block px-3 w-full h-8 bg-components-input-bg-normal system-sm-regular radius-md border border-transparent appearance-none outline-none caret-primary-600 hover:border-components-input-border-hover hover:bg-components-input-bg-hover focus:bg-components-input-bg-active focus:border-components-input-border-active focus:shadow-xs placeholder:system-sm-regular placeholder:text-components-input-text-placeholder' placeholder={t('workflow.chatVariable.modal.namePlaceholder') || ''} value={name} - onChange={e => handleNameChange(e.target.value)} + onChange={e => setName(e.target.value || '')} + onBlur={e => checkVariableName(e.target.value)} type='text' />
diff --git a/web/app/components/workflow/panel/env-panel/variable-modal.tsx b/web/app/components/workflow/panel/env-panel/variable-modal.tsx index 46f92bd8ed..c62a849f36 100644 --- a/web/app/components/workflow/panel/env-panel/variable-modal.tsx +++ b/web/app/components/workflow/panel/env-panel/variable-modal.tsx @@ -9,6 +9,7 @@ import { ToastContext } from '@/app/components/base/toast' import { useStore } from '@/app/components/workflow/store' import type { EnvironmentVariable } from '@/app/components/workflow/types' import cn from '@/utils/classnames' +import { checkKeys } from '@/utils/var' export type ModalPropsType = { env?: EnvironmentVariable @@ -28,19 +29,21 @@ const VariableModal = ({ const [name, setName] = React.useState('') const [value, setValue] = React.useState() - const handleNameChange = (v: string) => { - if (!v) - return setName('') - if (!/^[a-zA-Z0-9_]+$/.test(v)) - return notify({ type: 'error', message: 'name is can only contain letters, numbers and underscores' }) - if (/^[0-9]/.test(v)) - return notify({ type: 'error', message: 'name can not start with a number' }) - setName(v) + const checkVariableName = (value: string) => { + const { isValid, errorMessageKey } = checkKeys([value], false) + if (!isValid) { + notify({ + type: 'error', + message: t(`appDebug.varKeyError.${errorMessageKey}`, { key: t('workflow.env.modal.name') }), + }) + return false + } + return true } const handleSave = () => { - if (!name) - return notify({ type: 'error', message: 'name can not be empty' }) + if (!checkVariableName(name)) + return if (!value) return notify({ type: 'error', message: 'value can not be empty' }) if (!env && envList.some(env => env.name === name)) @@ -118,7 +121,8 @@ const VariableModal = ({ className='block px-3 w-full h-8 bg-components-input-bg-normal system-sm-regular radius-md border border-transparent appearance-none outline-none caret-primary-600 hover:border-components-input-border-hover hover:bg-components-input-bg-hover focus:bg-components-input-bg-active focus:border-components-input-border-active focus:shadow-xs placeholder:system-sm-regular placeholder:text-components-input-text-placeholder' placeholder={t('workflow.env.modal.namePlaceholder') || ''} value={name} - onChange={e => handleNameChange(e.target.value)} + onChange={e => setName(e.target.value || '')} + onBlur={e => checkVariableName(e.target.value)} type='text' />
From 0c38a8fdd43dd24f5daa5763a9a6f4e2acf823ca Mon Sep 17 00:00:00 2001 From: Nam Vu Date: Fri, 23 Aug 2024 14:25:07 +0700 Subject: [PATCH 11/24] Fix: voice language (#7570) --- web/i18n/de-DE/common.ts | 4 ++++ web/i18n/en-US/common.ts | 6 +++++- web/i18n/es-ES/common.ts | 4 ++++ web/i18n/fa-IR/common.ts | 4 ++++ web/i18n/fr-FR/common.ts | 6 +++++- web/i18n/hi-IN/common.ts | 4 ++++ web/i18n/it-IT/common.ts | 4 ++++ web/i18n/ja-JP/common.ts | 4 ++++ web/i18n/ko-KR/common.ts | 4 ++++ web/i18n/pl-PL/common.ts | 4 ++++ web/i18n/pt-BR/common.ts | 4 ++++ web/i18n/ro-RO/common.ts | 4 ++++ web/i18n/tr-TR/common.ts | 6 +++++- web/i18n/uk-UA/common.ts | 4 ++++ web/i18n/vi-VN/common.ts | 5 +++++ web/i18n/zh-Hans/common.ts | 4 ++++ web/i18n/zh-Hant/common.ts | 4 ++++ web/yarn.lock | 2 +- 18 files changed, 73 insertions(+), 4 deletions(-) diff --git a/web/i18n/de-DE/common.ts b/web/i18n/de-DE/common.ts index a7c7bb58fd..9a66d5b175 100644 --- a/web/i18n/de-DE/common.ts +++ b/web/i18n/de-DE/common.ts @@ -60,6 +60,10 @@ const translation = { ukUA: 'Ukrainisch', viVN: 'Vietnamesisch', plPL: 'Polnisch', + roRO: 'Rumänisch', + hiIN: 'Hindi', + trTR: 'Türkisch', + faIR: 'Persisch', }, }, unit: { diff --git a/web/i18n/en-US/common.ts b/web/i18n/en-US/common.ts index d48a9c233e..87dab5cb71 100644 --- a/web/i18n/en-US/common.ts +++ b/web/i18n/en-US/common.ts @@ -55,7 +55,7 @@ const translation = { frFR: 'French', esES: 'Spanish', itIT: 'Italian', - thTH: 'Thai.', + thTH: 'Thai', idID: 'Indonesian', jaJP: 'Japanese', koKR: 'Korean', @@ -64,6 +64,10 @@ const translation = { ukUA: 'Ukrainian', viVN: 'Vietnamese', plPL: 'Polish', + roRO: 'Romanian', + hiIN: 'Hindi', + trTR: 'Türkçe', + faIR: 'Farsi', }, }, unit: { diff --git a/web/i18n/es-ES/common.ts b/web/i18n/es-ES/common.ts index 4afc06b098..fc37775263 100644 --- a/web/i18n/es-ES/common.ts +++ b/web/i18n/es-ES/common.ts @@ -64,6 +64,10 @@ const translation = { ukUA: 'Ucraniano', viVN: 'Vietnamita', plPL: 'Polaco', + roRO: 'Rumano', + hiIN: 'Hindi', + trTR: 'Turco', + faIR: 'Persa', }, }, unit: { diff --git a/web/i18n/fa-IR/common.ts b/web/i18n/fa-IR/common.ts index 2b1c4647db..e4417bcbcc 100644 --- a/web/i18n/fa-IR/common.ts +++ b/web/i18n/fa-IR/common.ts @@ -64,6 +64,10 @@ const translation = { ukUA: 'اوکراینی', viVN: 'ویتنامی', plPL: 'لهستانی', + roRO: 'رومانیایی', + hiIN: 'هندی', + trTR: 'ترکی', + faIR: 'فارسی', }, }, unit: { diff --git a/web/i18n/fr-FR/common.ts b/web/i18n/fr-FR/common.ts index 424d8e4a99..2ae0731006 100644 --- a/web/i18n/fr-FR/common.ts +++ b/web/i18n/fr-FR/common.ts @@ -51,7 +51,7 @@ const translation = { frFR: 'Français', esES: 'Espagnol', itIT: 'Italien', - thTH: 'Thaï.', + thTH: 'Thaï', idID: 'Indonésien', jaJP: 'Japonais', koKR: 'Coréen', @@ -60,6 +60,10 @@ const translation = { ukUA: 'Ukrainien', viVN: 'Vietnamien', plPL: 'Polonais', + roRO: 'Roumain', + hiIN: 'Hindi', + trTR: 'Turc', + faIR: 'Persan', }, }, unit: { diff --git a/web/i18n/hi-IN/common.ts b/web/i18n/hi-IN/common.ts index 2be7cddf8f..0a210072e1 100644 --- a/web/i18n/hi-IN/common.ts +++ b/web/i18n/hi-IN/common.ts @@ -64,6 +64,10 @@ const translation = { ukUA: 'यूक्रेनी', viVN: 'वियतनामी', plPL: 'पोलिश', + roRO: 'रोमानियाई', + hiIN: 'हिन्दी', + trTR: 'तुर्की', + faIR: 'फ़ारसी', }, }, unit: { diff --git a/web/i18n/it-IT/common.ts b/web/i18n/it-IT/common.ts index cc9c34e2dc..595a5075eb 100644 --- a/web/i18n/it-IT/common.ts +++ b/web/i18n/it-IT/common.ts @@ -64,6 +64,10 @@ const translation = { ukUA: 'Ucraino', viVN: 'Vietnamita', plPL: 'Polacco', + roRO: 'Rumeno', + hiIN: 'Hindi', + trTR: 'Turco', + faIR: 'Persiano', }, }, unit: { diff --git a/web/i18n/ja-JP/common.ts b/web/i18n/ja-JP/common.ts index f3fb8466f1..fc61141bd3 100644 --- a/web/i18n/ja-JP/common.ts +++ b/web/i18n/ja-JP/common.ts @@ -64,6 +64,10 @@ const translation = { ukUA: 'ウクライナ語', viVN: 'ベトナム語', plPL: 'ポーランド語', + roRO: 'ルーマニア語', + hiIN: 'ヒンディー語', + trTR: 'トルコ語', + faIR: 'ペルシア語', }, }, unit: { diff --git a/web/i18n/ko-KR/common.ts b/web/i18n/ko-KR/common.ts index 9e78510078..edd0295b89 100644 --- a/web/i18n/ko-KR/common.ts +++ b/web/i18n/ko-KR/common.ts @@ -60,6 +60,10 @@ const translation = { ukUA: '우크라이나어', viVN: '베트남어', plPL: '폴란드어', + roRO: '루마니아어', + hiIN: '힌디어', + trTR: '터키어', + faIR: '페르시아어', }, }, unit: { diff --git a/web/i18n/pl-PL/common.ts b/web/i18n/pl-PL/common.ts index 39572ce09b..1f41abe154 100644 --- a/web/i18n/pl-PL/common.ts +++ b/web/i18n/pl-PL/common.ts @@ -60,6 +60,10 @@ const translation = { ukUA: 'Ukraiński', viVN: 'Wietnamski', plPL: 'Polski', + roRO: 'Rumuński', + hiIN: 'Hindi', + trTR: 'Turecki', + faIR: 'Perski', }, }, unit: { diff --git a/web/i18n/pt-BR/common.ts b/web/i18n/pt-BR/common.ts index 1b29d06669..f93979404b 100644 --- a/web/i18n/pt-BR/common.ts +++ b/web/i18n/pt-BR/common.ts @@ -60,6 +60,10 @@ const translation = { ukUA: 'Ucraniano', viVN: 'Vietnamita', plPL: 'Polonês', + roRO: 'Romeno', + hiIN: 'Hindi', + trTR: 'Turco', + faIR: 'Persa', }, }, unit: { diff --git a/web/i18n/ro-RO/common.ts b/web/i18n/ro-RO/common.ts index e7037f65b8..34ca1c4671 100644 --- a/web/i18n/ro-RO/common.ts +++ b/web/i18n/ro-RO/common.ts @@ -59,6 +59,10 @@ const translation = { ruRU: 'Rusă', ukUA: 'Ucraineană', viVN: 'Vietnameză', + roRO: 'Română', + hiIN: 'Hindi', + trTR: 'Turcă', + faIR: 'Persană', }, }, unit: { diff --git a/web/i18n/tr-TR/common.ts b/web/i18n/tr-TR/common.ts index f7981c3f48..a194ffd769 100644 --- a/web/i18n/tr-TR/common.ts +++ b/web/i18n/tr-TR/common.ts @@ -55,7 +55,7 @@ const translation = { frFR: 'French', esES: 'Spanish', itIT: 'Italian', - thTH: 'Thai.', + thTH: 'Thai', idID: 'Indonesian', jaJP: 'Japanese', koKR: 'Korean', @@ -64,6 +64,10 @@ const translation = { ukUA: 'Ukrainian', viVN: 'Vietnamese', plPL: 'Polish', + roRO: 'Romence', + hiIN: 'Hintçe', + trTR: 'Türkçe', + faIR: 'Farsça', }, }, unit: { diff --git a/web/i18n/uk-UA/common.ts b/web/i18n/uk-UA/common.ts index fb0003d35f..33324ce0f2 100644 --- a/web/i18n/uk-UA/common.ts +++ b/web/i18n/uk-UA/common.ts @@ -60,6 +60,10 @@ const translation = { ukUA: 'Українська', viVN: 'В\'є тнамська', plPL: 'Польська', + roRO: 'Румунська', + hiIN: 'Хінді', + trTR: 'Турецька', + faIR: 'Перська', }, }, unit: { diff --git a/web/i18n/vi-VN/common.ts b/web/i18n/vi-VN/common.ts index 232148ce74..19855d31f0 100644 --- a/web/i18n/vi-VN/common.ts +++ b/web/i18n/vi-VN/common.ts @@ -45,6 +45,7 @@ const translation = { voice: { language: { zhHans: 'Tiếng Trung', + zhHant: 'Tiếng Trung phồn thể', enUS: 'Tiếng Anh', deDE: 'Tiếng Đức', frFR: 'Tiếng Pháp', @@ -59,6 +60,10 @@ const translation = { ukUA: 'Tiếng Ukraina', viVN: 'Tiếng Việt', plPL: 'Tiếng Ba Lan', + roRO: 'Tiếng Rumani', + hiIN: 'Tiếng Hindi', + trTR: 'Tiếng Thổ Nhĩ Kỳ', + faIR: 'Tiếng Ba Tư', }, }, unit: { diff --git a/web/i18n/zh-Hans/common.ts b/web/i18n/zh-Hans/common.ts index e0072e2cba..73cf435a7c 100644 --- a/web/i18n/zh-Hans/common.ts +++ b/web/i18n/zh-Hans/common.ts @@ -64,6 +64,10 @@ const translation = { ukUA: '乌克兰语', viVN: '越南语', plPL: '波兰语', + roRO: '罗马尼亚语', + hiIN: '印地语', + trTR: '土耳其语', + faIR: '波斯语', }, }, unit: { diff --git a/web/i18n/zh-Hant/common.ts b/web/i18n/zh-Hant/common.ts index f4d6952f76..e14c3c3196 100644 --- a/web/i18n/zh-Hant/common.ts +++ b/web/i18n/zh-Hant/common.ts @@ -60,6 +60,10 @@ const translation = { ukUA: '烏克蘭語', viVN: '越南語', plPL: '波蘭語', + roRO: '羅馬尼亞語', + hiIN: '印地語', + trTR: '土耳其語', + faIR: '波斯語', }, }, unit: { diff --git a/web/yarn.lock b/web/yarn.lock index f6a6694b51..d50aa33f3e 100644 --- a/web/yarn.lock +++ b/web/yarn.lock @@ -9373,4 +9373,4 @@ zustand@^4.4.1, zustand@^4.5.2: zwitch@^2.0.0: version "2.0.4" resolved "https://registry.npmjs.org/zwitch/-/zwitch-2.0.4.tgz" - integrity sha512-bXE4cR/kVZhKZX/RjPEflHaKVhUVl85noU3v6b8apfQEc1x4A+zBxjZ4lN8LqGd6WZ3dl98pY4o717VFmoPp+A== + integrity sha512-bXE4cR/kVZhKZX/RjPEflHaKVhUVl85noU3v6b8apfQEc1x4A+zBxjZ4lN8LqGd6WZ3dl98pY4o717VFmoPp+A== \ No newline at end of file From 60250029712f1ae864eb5fe57bbbe4488dd32acf Mon Sep 17 00:00:00 2001 From: Fei He Date: Fri, 23 Aug 2024 15:32:38 +0800 Subject: [PATCH 12/24] add qwen text-embedding-v3 support. (#7567) --- .../tongyi/text_embedding/text-embedding-v3.yaml | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v3.yaml diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v3.yaml b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v3.yaml new file mode 100644 index 0000000000..171a379ee2 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v3.yaml @@ -0,0 +1,9 @@ +model: text-embedding-v3 +model_type: text-embedding +model_properties: + context_size: 8192 + max_chunks: 25 +pricing: + input: "0.0007" + unit: "0.001" + currency: RMB From 9864b354657910ea05812055658e232c29f1200e Mon Sep 17 00:00:00 2001 From: "Charlie.Wei" Date: Fri, 23 Aug 2024 15:53:49 +0800 Subject: [PATCH 13/24] langfuser add view button (#7571) --- api/core/ops/entities/config_entity.py | 1 + api/core/ops/langfuse_trace/langfuse_trace.py | 8 +++++ api/core/ops/ops_trace_manager.py | 16 ++++++++-- api/services/ops_service.py | 11 ++++++- .../overview/tracing/provider-panel.tsx | 31 +++++++++++++++---- web/i18n/en-US/app.ts | 1 + web/i18n/zh-Hans/app.ts | 1 + web/i18n/zh-Hant/app.ts | 1 + 8 files changed, 61 insertions(+), 9 deletions(-) diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index 221e6239ab..447f668e26 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -21,6 +21,7 @@ class LangfuseConfig(BaseTracingConfig): """ public_key: str secret_key: str + project_key: str host: str = 'https://api.langfuse.com' @field_validator("host") diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 698398e0cb..a21c67ed50 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -419,3 +419,11 @@ class LangFuseDataTrace(BaseTraceInstance): except Exception as e: logger.debug(f"LangFuse API check failed: {str(e)}") raise ValueError(f"LangFuse API check failed: {str(e)}") + + def get_project_key(self): + try: + projects = self.langfuse_client.client.projects.get() + return projects.data[0].id + except Exception as e: + logger.debug(f"LangFuse get project key failed: {str(e)}") + raise ValueError(f"LangFuse get project key failed: {str(e)}") diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 068b490ec8..1416d6bd2d 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -38,7 +38,7 @@ provider_config_map = { TracingProviderEnum.LANGFUSE.value: { 'config_class': LangfuseConfig, 'secret_keys': ['public_key', 'secret_key'], - 'other_keys': ['host'], + 'other_keys': ['host', 'project_key'], 'trace_instance': LangFuseDataTrace }, TracingProviderEnum.LANGSMITH.value: { @@ -123,7 +123,6 @@ class OpsTraceManager: for key in other_keys: new_config[key] = decrypt_tracing_config.get(key, "") - return config_class(**new_config).model_dump() @classmethod @@ -252,6 +251,19 @@ class OpsTraceManager: tracing_config = config_type(**tracing_config) return trace_instance(tracing_config).api_check() + @staticmethod + def get_trace_config_project_key(tracing_config: dict, tracing_provider: str): + """ + get trace config is project key + :param tracing_config: tracing config + :param tracing_provider: tracing provider + :return: + """ + config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \ + provider_config_map[tracing_provider]['trace_instance'] + tracing_config = config_type(**tracing_config) + return trace_instance(tracing_config).get_project_key() + class TraceTask: def __init__( diff --git a/api/services/ops_service.py b/api/services/ops_service.py index ffc12a9acd..7b2edcf7cb 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -22,6 +22,10 @@ class OpsService: # decrypt_token and obfuscated_token tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(tenant_id, tracing_provider, trace_config_data.tracing_config) + if tracing_provider == 'langfuse' and ('project_key' not in decrypt_tracing_config or not decrypt_tracing_config.get('project_key')): + project_key = OpsTraceManager.get_trace_config_project_key(decrypt_tracing_config, tracing_provider) + decrypt_tracing_config['project_key'] = project_key + decrypt_tracing_config = OpsTraceManager.obfuscated_decrypt_token(tracing_provider, decrypt_tracing_config) trace_config_data.tracing_config = decrypt_tracing_config @@ -37,7 +41,7 @@ class OpsService: :param tracing_config: tracing config :return: """ - if tracing_provider not in provider_config_map.keys() and tracing_provider != None: + if tracing_provider not in provider_config_map.keys() and tracing_provider: return {"error": f"Invalid tracing provider: {tracing_provider}"} config_class, other_keys = provider_config_map[tracing_provider]['config_class'], \ @@ -51,6 +55,9 @@ class OpsService: if not OpsTraceManager.check_trace_config_is_effective(tracing_config, tracing_provider): return {"error": "Invalid Credentials"} + # get project key + project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider) + # check if trace config already exists trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter( TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider @@ -62,6 +69,8 @@ class OpsService: # get tenant id tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config) + if tracing_provider == 'langfuse': + tracing_config['project_key'] = project_key trace_config_data = TraceAppConfig( app_id=app_id, tracing_provider=tracing_provider, diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx index 120fe29dff..4a39586064 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx @@ -6,6 +6,7 @@ import { TracingProvider } from './type' import cn from '@/utils/classnames' import { LangfuseIconBig, LangsmithIconBig } from '@/app/components/base/icons/src/public/tracing' import { Settings04 } from '@/app/components/base/icons/src/vender/line/general' +import { Eye as View } from '@/app/components/base/icons/src/vender/solid/general' const I18N_PREFIX = 'app.tracing' @@ -13,6 +14,7 @@ type Props = { type: TracingProvider readOnly: boolean isChosen: boolean + Config: any onChoose: () => void hasConfigured: boolean onConfig: () => void @@ -29,6 +31,7 @@ const ProviderPanel: FC = ({ type, readOnly, isChosen, + Config, onChoose, hasConfigured, onConfig, @@ -41,6 +44,14 @@ const ProviderPanel: FC = ({ onConfig() }, [onConfig]) + const viewBtnClick = useCallback((e: React.MouseEvent) => { + e.preventDefault() + e.stopPropagation() + + const url = `${Config?.host}/project/${Config?.project_key}` + window.open(url, '_blank', 'noopener,noreferrer') + }, [Config?.host, Config?.project_key]) + const handleChosen = useCallback((e: React.MouseEvent) => { e.stopPropagation() if (isChosen || !hasConfigured || readOnly) @@ -58,12 +69,20 @@ const ProviderPanel: FC = ({ {isChosen &&
{t(`${I18N_PREFIX}.inUse`)}
} {!readOnly && ( -
- -
{t(`${I18N_PREFIX}.config`)}
+
+ {hasConfigured && ( +
+ +
{t(`${I18N_PREFIX}.view`)}
+
+ )} +
+ +
{t(`${I18N_PREFIX}.config`)}
+
)} diff --git a/web/i18n/en-US/app.ts b/web/i18n/en-US/app.ts index 39b47c8eb4..90724098de 100644 --- a/web/i18n/en-US/app.ts +++ b/web/i18n/en-US/app.ts @@ -95,6 +95,7 @@ const translation = { title: 'Tracing app performance', description: 'Configuring a Third-Party LLMOps provider and tracing app performance.', config: 'Config', + view: 'View', collapse: 'Collapse', expand: 'Expand', tracing: 'Tracing', diff --git a/web/i18n/zh-Hans/app.ts b/web/i18n/zh-Hans/app.ts index 6703e1ca95..e12ed1b35d 100644 --- a/web/i18n/zh-Hans/app.ts +++ b/web/i18n/zh-Hans/app.ts @@ -94,6 +94,7 @@ const translation = { title: '追踪应用性能', description: '配置第三方 LLMOps 提供商并跟踪应用程序性能。', config: '配置', + view: '查看', collapse: '折叠', expand: '展开', tracing: '追踪', diff --git a/web/i18n/zh-Hant/app.ts b/web/i18n/zh-Hant/app.ts index 4b915b7f2d..ff162c5b61 100644 --- a/web/i18n/zh-Hant/app.ts +++ b/web/i18n/zh-Hant/app.ts @@ -90,6 +90,7 @@ const translation = { title: '追蹤應用程式效能', description: '配置第三方LLMOps提供商並追蹤應用程式效能。', config: '配置', + view: '查看', collapse: '收起', expand: '展開', tracing: '追蹤', From df69ad9f0e4336f7cab62c6fdb4f53f03a290944 Mon Sep 17 00:00:00 2001 From: "Charlie.Wei" Date: Fri, 23 Aug 2024 16:23:26 +0800 Subject: [PATCH 14/24] Langfuse view button (#7578) --- .../[appId]/overview/tracing/config-popup.tsx | 2 ++ .../[appId]/overview/tracing/provider-panel.tsx | 8 ++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx index 7aa1fca96d..7744bab7ba 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx @@ -85,6 +85,7 @@ const ConfigPopup: FC = ({ = ({ void hasConfigured: boolean onConfig: () => void @@ -31,7 +31,7 @@ const ProviderPanel: FC = ({ type, readOnly, isChosen, - Config, + config, onChoose, hasConfigured, onConfig, @@ -48,9 +48,9 @@ const ProviderPanel: FC = ({ e.preventDefault() e.stopPropagation() - const url = `${Config?.host}/project/${Config?.project_key}` + const url = `${config?.host}/project/${config?.project_key}` window.open(url, '_blank', 'noopener,noreferrer') - }, [Config?.host, Config?.project_key]) + }, []) const handleChosen = useCallback((e: React.MouseEvent) => { e.stopPropagation() From ad130110430b4a96a5a63f508fe98b0213e4a200 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=99=A2=E5=93=8E=E5=93=9F=E5=96=82?= Date: Fri, 23 Aug 2024 16:24:45 +0800 Subject: [PATCH 15/24] add JSON Mode support for moonshot models (#7568) --- .../moonshot/llm/moonshot-v1-128k.yaml | 12 ++++++++++++ .../moonshot/llm/moonshot-v1-32k.yaml | 12 ++++++++++++ .../model_providers/moonshot/llm/moonshot-v1-8k.yaml | 12 ++++++++++++ 3 files changed, 36 insertions(+) diff --git a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml index 0d2e51c47f..1078e84c59 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml +++ b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml @@ -21,6 +21,18 @@ parameter_rules: default: 1024 min: 1 max: 128000 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.06' output: '0.06' diff --git a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml index 9ff537014a..9c739d0501 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml +++ b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml @@ -21,6 +21,18 @@ parameter_rules: default: 1024 min: 1 max: 32000 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.024' output: '0.024' diff --git a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml index 0f308d3676..187a86999e 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml +++ b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml @@ -21,6 +21,18 @@ parameter_rules: default: 512 min: 1 max: 8192 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.012' output: '0.012' From f29685f8a122d5d7c8ce124d86abc797a5458209 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E7=9A=AE=E7=9A=AE?= <363230482@qq.com> Date: Fri, 23 Aug 2024 16:59:34 +0800 Subject: [PATCH 16/24] fix score_threshold is none, return all top K documents (#7581) --- api/core/rag/retrieval/dataset_retrieval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index e945364796..fc6d231f8e 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -614,7 +614,7 @@ class DatasetRetrieval: top_k: int, score_threshold: float) -> list[Document]: filter_documents = [] for document in all_documents: - if score_threshold and document.metadata['score'] >= score_threshold: + if score_threshold is None or document.metadata['score'] >= score_threshold: filter_documents.append(document) if not filter_documents: return [] From 25386af41aa961f332258190fad63a1639adb8ac Mon Sep 17 00:00:00 2001 From: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com> Date: Fri, 23 Aug 2024 17:20:19 +0800 Subject: [PATCH 17/24] fix: knowledge setting "knowledge name" input width (#7584) --- .../components/datasets/settings/form/index.tsx | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/web/app/components/datasets/settings/form/index.tsx b/web/app/components/datasets/settings/form/index.tsx index f7519248e6..404a8ed6a0 100644 --- a/web/app/components/datasets/settings/form/index.tsx +++ b/web/app/components/datasets/settings/form/index.tsx @@ -163,12 +163,14 @@ const Form = () => {
{t('datasetSettings.form.name')}
- setName(e.target.value)} - /> +
+ setName(e.target.value)} + /> +
From e42848f4b7071d0ca188c1eb03b19e009bd838a8 Mon Sep 17 00:00:00 2001 From: Amos Date: Fri, 23 Aug 2024 11:50:38 +0100 Subject: [PATCH 18/24] Do not pass query parameter when the value is empty (#7585) --- api/core/tools/tool/api_tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index 69e3dfa061..38f10032e2 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -144,7 +144,7 @@ class ApiTool(Tool): path_params[parameter['name']] = value elif parameter['in'] == 'query': - params[parameter['name']] = value + if value !='': params[parameter['name']] = value elif parameter['in'] == 'cookie': cookies[parameter['name']] = value From 70d6ab0bf5c68d44f1a62dc3e4e744f6a31fe0e3 Mon Sep 17 00:00:00 2001 From: "Jie.F" Date: Fri, 23 Aug 2024 18:58:13 +0800 Subject: [PATCH 19/24] Update stable_diffusion.py (#7536) --- .../stablediffusion/tools/stable_diffusion.py | 41 ++++++++++++++++++- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py index 0c5ebc23ac..4be9207d66 100644 --- a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py @@ -27,7 +27,7 @@ DRAW_TEXT_OPTIONS = { "seed_resize_from_w": -1, # Samplers - # "sampler_name": "DPM++ 2M", + "sampler_name": "DPM++ 2M", # "scheduler": "", # "sampler_index": "Automatic", @@ -178,6 +178,23 @@ class StableDiffusionTool(BuiltinTool): return [d['model_name'] for d in response.json()] except Exception as e: return [] + + def get_sample_methods(self) -> list[str]: + """ + get sample method + """ + try: + base_url = self.runtime.credentials.get('base_url', None) + if not base_url: + return [] + api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'samplers') + response = get(url=api_url, timeout=(2, 10)) + if response.status_code != 200: + return [] + else: + return [d['name'] for d in response.json()] + except Exception as e: + return [] def img2img(self, base_url: str, tool_parameters: dict[str, Any]) \ -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: @@ -339,7 +356,27 @@ class StableDiffusionTool(BuiltinTool): label=I18nObject(en_US=i, zh_Hans=i) ) for i in models]) ) + except: pass - + + sample_methods = self.get_sample_methods() + if len(sample_methods) != 0: + parameters.append( + ToolParameter(name='sampler_name', + label=I18nObject(en_US='Sampling method', zh_Hans='Sampling method'), + human_description=I18nObject( + en_US='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion', + zh_Hans='Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档', + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion', + required=True, + default=sample_methods[0], + options=[ToolParameterOption( + value=i, + label=I18nObject(en_US=i, zh_Hans=i) + ) for i in sample_methods]) + ) return parameters From 8807d880dc50a540e74463c5c05ef277ef4bb57e Mon Sep 17 00:00:00 2001 From: Junyan Qin <1010553892@qq.com> Date: Fri, 23 Aug 2024 19:16:30 +0800 Subject: [PATCH 20/24] Feat: add OneBot protocol tool (#7583) --- .../provider/builtin/onebot/_assets/icon.ico | Bin 0 -> 38078 bytes .../tools/provider/builtin/onebot/onebot.py | 12 ++++ .../tools/provider/builtin/onebot/onebot.yaml | 35 ++++++++++ .../provider/builtin/onebot/tools/__init__.py | 0 .../builtin/onebot/tools/send_group_msg.py | 62 ++++++++++++++++++ .../builtin/onebot/tools/send_group_msg.yaml | 46 +++++++++++++ .../builtin/onebot/tools/send_private_msg.py | 61 +++++++++++++++++ .../onebot/tools/send_private_msg.yaml | 46 +++++++++++++ 8 files changed, 262 insertions(+) create mode 100644 api/core/tools/provider/builtin/onebot/_assets/icon.ico create mode 100644 api/core/tools/provider/builtin/onebot/onebot.py create mode 100644 api/core/tools/provider/builtin/onebot/onebot.yaml create mode 100644 api/core/tools/provider/builtin/onebot/tools/__init__.py create mode 100644 api/core/tools/provider/builtin/onebot/tools/send_group_msg.py create mode 100644 api/core/tools/provider/builtin/onebot/tools/send_group_msg.yaml create mode 100644 api/core/tools/provider/builtin/onebot/tools/send_private_msg.py create mode 100644 api/core/tools/provider/builtin/onebot/tools/send_private_msg.yaml diff --git a/api/core/tools/provider/builtin/onebot/_assets/icon.ico b/api/core/tools/provider/builtin/onebot/_assets/icon.ico new file mode 100644 index 0000000000000000000000000000000000000000..1b07e965b9910b4b006bc112378a8ba0306895a8 GIT binary patch literal 38078 zcmeI52douE7soFaiWQKeqLElp5s8Xju@}S-kfNf5r-`46AYwryvBqdL8u_AVVuOeY zU@u4%6rWg8?4S@!C>AUziqc`e|8L)$Wp>$pckk}r`yR+4|F^Sqch8(TGdpwU%-N+- zXrzBtRR#TRUl@OXq0pvKD0C2%Rj3xo_h@9zp+5f_V+uvEnSO66bQXFGBZS9ft`Pu-NW>>~BvTEBM`s)PrGw}b`4-@+Op zEpabA;5AJcB@7fOYbznPEy}5jO7gCwj!pC%ef&yclJJADS|~TIQUR1jnZpERK$gh1 zsHCp)EGw=)%5Ni_Eld!86zZFhVZ6{!Xf15;?GRIkU6emWcvo01G&Dk{cZ9(LvaZ(- zdG^n8@QRXId2I#s{;^Pr@kPaCB4qqXpdENt8+ji z;1U6y_&SQyW14)IZ*4i~ezCBLYBcbN4zxk;&w+UQZK^cmJ!8(MelI12i05FS;OZ+) zclYlJ^U^tUZG#%`moHy#!0OcvVu{UskjtDLp?hr1irP7A)~qqV{q~#r@%}y(INwn3KMow zX4_Oe+~8GDI z&6+jye&ct*dAoM)%+W_5ZLYZD3iIr<&w4y7*cPl00&6sXP)U#o>7MmA>)NssW5|2& zy=Shu<{HzfQ>XQGYIWP9MGN!FE3f#nEL^zI9C5@E0sY%+?UD94@W2Dj@ZrPFyYIeh zmMvT6-(SWA>*Rp1B7y3?OLU*x_nB}1DVLlpSFSYErcLw4kG5^w)~QFUPxNT}GRLd0 zzB* zC!BDCH?~VQkH(D~=Z{bJnqo)F95Q5xnK^T&WmAqLbHV;XqBd}^QyQQ57$JA!`5!xW ztU2I-0}}UJSD(A@zPtJAE2>sQ)2B~2ty{NF#cTil_cxC|`l$K)@3OhdiNgr&AE+*o zr0e{4Qo65@yHG#<^piIyZoBQasq`hmP0N-o&6{t&De-FP>#x5yd+f1CYJS^nvyHjz zvdg?V+Vaft$o>H_Furc1l<)qobX$Rat6XW;tXZae_wJ@qqeiK9h5onRdTaCIi!b`J zEm*L?bnMtM6HjEqw)*tbPkp{QCg?1Iz07bM=y$(;lpZ4Fx}f>+!w-wvDd;u{utl*c zEn% zn0+4CCwJa?XVa%oA8(EO!3Q50Y<=c#)@WhEj$OQXvH9kkZ_K1glgte_+~BRh(zXT9 zLC#|lzN!(u`5W=j1}nwT$<1`i;wP-@d))OL+hN_su{5q+fdc^2;xtFM?;L zdGqEa%fWoW*p?%`EnckzzYRj^vvp1!Qf6X3%i6z`xpuqlwlf!Ac%fmuTD6KHGap*A zWQlqH`RC1vC!S~+%i`KSoR92-1`W#7HeIF~VOK`!pBTsq`PA5$C!Tnsq;a@IhYn`U zm@#?Qb-k+m_~Vb}#v5-mty;B8+y>aDtPgUeapHxp!)-tuD*n^VG-u8n&xaH4PjTnR z9e14P`^j8z0tU_yj(%%l8Lc)a%7 zYi6&#_Da+S_%~jE9k+5WG*MjK_Ya7zH5a!i2=Wqw~){ zKa)M?)b)F%_Y>?hP^EO{^)B`z{@u7|*yiKI$7k6Pi22}_TW(2Y7qw~A#`BfuAUDR{ z00I58Xa9hZiC7a|d+oJFbx#`%9z57AUAi<=d3oNjx^hKV)vF~=NZ zdiU;~+NOH><(JL&+izdA4X{OXh<8Efy@V0Mns9weIsg6l-#uTK?dNgFi!Qn-^`7gD z88Zy?HoWh=^Un3e8#!{MuQ$f>GHOjbhRs#!YAJdj}%ySk1_6U&xm&Pu9(#wf-) zY#IE=M;&!kNn3f(QM5j+Vba=hP7C1gOuxaQo6+fj0RxK4d)HleImMQqOsapz z4r1J}m+)s3n-aBO7AFDqx%lFXi`EBQEQk25C6Z%hoGeN5@d@_q*|Sc0+4Fehkw=oS zi@l~4{lEP-+f3={`s=T+vpziEJR@25=4z$Gp^85=HQ`5KY>ygqY+mEWjlKPt)P?BA zO6os}4RrF!CmZ(FG7>(w%=$n5^wT%#KV3th|B8%14?XlyItgp6>=_n?^VIpNYwc>-RSQ z@WT)H*4JsuaPGnO94=IB|7F)*clB%8X|6& z_Op79JMykZ7GGAaY%Jb-2*hGlWdD)+@U9#&2Wf~o8M`RT?jVo-$Fx3HdZ3A?AcK>W$5tla2- zN#jq5V+P#P*^2oqNMRxyuUSS)TliC=b7Z? zX5v-xw{5xQmPN;X;;*9So)khb0QVQrJ;*KokobMuSL!-44!!#7t17yFiSHMn*Sci2 zzcD}L^zJ+_b#xG{{)q*xc>L`_2OVU1X6lRhmd8H(?2}0Mv=P_mM9+&)Gr{Vg>Mi-vh$6O9AmYYN}{ubDBY9STFC9?Z{9yW^UO07kFmDitb>Rj zdEtc@yx5Vb^{(VBO00(&*YI1|y&GFbk|TZLhaWP(mg34u%romROmyAX@>*I$2mvEkUYyi?J&YgYr0 zFyAES#CWES6^>CH>#IuyU*D3$BR5a!U4)N>+z4OoA%`52N)HmyyN_|0Z=4)^?6KZ{ z7;Qp~k0sjqth3HCykC$chS~CsI${^I_n9-na3^6s{nH3;zS6lq1LEiT%hP^`AWBE< zhY>4iiFfH(S4Ul2N*?{b47z8%d$Hi_SuzF4{Y~jMLS~((7#TPq<}|1^gCyV8Ax%mGH?5g_Jd?4q>a#jrZ$9SC>5 zaR2@H`!Tonnxsc$AwHEji!#t(;@MO1^((m|$-PEt*2klT+zH#1^+74?PYLWl2gK8j z9zEKZi@BKgaIZ<4X0QG6#~;tjuj9+)Bpz0}sZdn^G>Utj(yY@m$ES-&B(YBBy`@sW zzsI`-?s}AYxYW6YGU@YK-bsk8R4|i8^S*+wS9DyA-0PLD5_ngsEW{YXJM;KWle{}D z*}cB*`eG(eo?J9$f%^b@=UcgaKQ8S)TtY=Kb`KCr(mieC-lH_nPUig?@u-u~C%#bj zaQIfLjrFm5vq$_Cw3j7x$bPK-J?aQA#yY+m%HCT}-*K=a#2xR`5BPL^eM+{(vJ-16_A;?&7@Jssp>y_KiNoSMV{YF^pW%1UxR!4g^X>)T7v)=Ve521@ua&tX zQ^bE~A(ifFB=_FhbQc{n$I6LEVnScU%n@IN4vFD)zemOI_{Mtl=;7_1V^fF|gG5=s8pu@x+-?=c%L+9OvvKVRNn~E4qfNP7UVwBHKIE%UonRRcQA==bOzos>oQwkcfxO4HT@g$6YrNS;a(<2_?e+4{ZJy*;Wdzm03t%*hqCi&S1i zAncr6=US`dVn>x5ZUb!7imVSBg5o6u_I)lkW-(pI{jT!EZNM6hb#mE!72;jrginEW zZD4Js^Efq?HJ@@rz*jL)*citn#v8`^fKNf^Wz}&p=9C`+{?85GtD(Pa=!E|sA9BEt zsdKO1#q?8&->od6Bl`!v1@`<`2^$)%l)&s;v3J6Ay*$;q@WDbyp|KFqw>{sGbX}LMwmvb(t(3o?fPII3%{V}C4{M1vLRwrYhE?yOO=UCcPW!xix7<5_5A+ zni~4uN6ED{*@ePQKAS?JtIwuT*v)6-Nio-6FQ(NtEs1WnX%%3TlFnzZ0&1n?*RF25 zR$}sZbJJeM$yLE~y@U6siJPwWvb;Z=yXl+P#{ecHbxD&ZSAJ3*Ht#qPv)-6 zZ{nuizUTHyH|_ReH|_R$_gM%$Bh|KGJv)BEdZzq>^~C#y=^6A3(6i|maFbd0uqkN^ d@SgEnrrxt(%V;;Qsb$oav^$0b#uk6f`9C{jz#jkr literal 0 HcmV?d00001 diff --git a/api/core/tools/provider/builtin/onebot/onebot.py b/api/core/tools/provider/builtin/onebot/onebot.py new file mode 100644 index 0000000000..42f321e919 --- /dev/null +++ b/api/core/tools/provider/builtin/onebot/onebot.py @@ -0,0 +1,12 @@ +from typing import Any + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class OneBotProvider(BuiltinToolProviderController): + + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + + if not credentials.get("ob11_http_url"): + raise ToolProviderCredentialValidationError('OneBot HTTP URL is required.') diff --git a/api/core/tools/provider/builtin/onebot/onebot.yaml b/api/core/tools/provider/builtin/onebot/onebot.yaml new file mode 100644 index 0000000000..1922adc4de --- /dev/null +++ b/api/core/tools/provider/builtin/onebot/onebot.yaml @@ -0,0 +1,35 @@ +identity: + author: RockChinQ + name: onebot + label: + en_US: OneBot v11 Protocol + zh_Hans: OneBot v11 协议 + description: + en_US: Unofficial OneBot v11 Protocol Tool + zh_Hans: 非官方 OneBot v11 协议工具 + icon: icon.ico +credentials_for_provider: + ob11_http_url: + type: text-input + required: true + label: + en_US: HTTP URL + zh_Hans: HTTP URL + description: + en_US: Forward HTTP URL of OneBot v11 + zh_Hans: OneBot v11 正向 HTTP URL + help: + en_US: Fill this with the HTTP URL of your OneBot server + zh_Hans: 请在你的 OneBot 协议端开启 正向 HTTP 并填写其 URL + access_token: + type: secret-input + required: false + label: + en_US: Access Token + zh_Hans: 访问令牌 + description: + en_US: Access Token for OneBot v11 Protocol + zh_Hans: OneBot 协议访问令牌 + help: + en_US: Fill this if you set a access token in your OneBot server + zh_Hans: 如果你在 OneBot 服务器中设置了 access token,请填写此项 diff --git a/api/core/tools/provider/builtin/onebot/tools/__init__.py b/api/core/tools/provider/builtin/onebot/tools/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py new file mode 100644 index 0000000000..802bb7b610 --- /dev/null +++ b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py @@ -0,0 +1,62 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class SendGroupMsg(BuiltinTool): + """OneBot v11 Tool: Send Group Message""" + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + # Get parameters + send_group_id = tool_parameters.get('group_id', '') + + message = tool_parameters.get('message', '') + if not message: + return self.create_json_message( + { + 'error': 'Message is empty.' + } + ) + + auto_escape = tool_parameters.get('auto_escape', False) + + try: + + resp = requests.post( + f'{self.runtime.credentials['ob11_http_url']}/send_group_msg', + json={ + 'group_id': send_group_id, + 'message': message, + 'auto_escape': auto_escape + }, + headers={ + 'Authorization': 'Bearer ' + self.runtime.credentials['access_token'] + } + ) + + if resp.status_code != 200: + return self.create_json_message( + { + 'error': f'Failed to send group message: {resp.text}' + } + ) + + return self.create_json_message( + { + 'response': resp.json() + } + ) + except Exception as e: + return self.create_json_message( + { + 'error': f'Failed to send group message: {e}' + } + ) diff --git a/api/core/tools/provider/builtin/onebot/tools/send_group_msg.yaml b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.yaml new file mode 100644 index 0000000000..64beaa8545 --- /dev/null +++ b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.yaml @@ -0,0 +1,46 @@ +identity: + name: send_group_msg + author: RockChinQ + label: + en_US: Send Group Message + zh_Hans: 发送群消息 +description: + human: + en_US: Send a message to a group + zh_Hans: 发送消息到群聊 + llm: A tool for sending a message segment to a group +parameters: + - name: group_id + type: number + required: true + label: + en_US: Target Group ID + zh_Hans: 目标群 ID + human_description: + en_US: The group ID of the target group + zh_Hans: 目标群的群 ID + llm_description: The group ID of the target group + form: llm + - name: message + type: string + required: true + label: + en_US: Message + zh_Hans: 消息 + human_description: + en_US: The message to send + zh_Hans: 要发送的消息。支持 CQ码(需要同时设置 auto_escape 为 true) + llm_description: The message to send + form: llm + - name: auto_escape + type: boolean + required: false + default: false + label: + en_US: Auto Escape + zh_Hans: 自动转义 + human_description: + en_US: If true, the message will be treated as a CQ code for parsing, otherwise it will be treated as plain text for direct sending. Since Dify currently does not support passing Object-format message chains, developers can send complex message components through CQ codes. + zh_Hans: 若为 true 则会把 message 视为 CQ 码解析,否则视为 纯文本 直接发送。由于 Dify 目前不支持传入 Object格式 的消息,故开发者可以通过 CQ 码来发送复杂消息组件。 + llm_description: If true, the message will be treated as a CQ code for parsing, otherwise it will be treated as plain text for direct sending. + form: form diff --git a/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py new file mode 100644 index 0000000000..a11c7a46c0 --- /dev/null +++ b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py @@ -0,0 +1,61 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class SendPrivateMsg(BuiltinTool): + """OneBot v11 Tool: Send Private Message""" + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + # Get parameters + send_user_id = tool_parameters.get('user_id', '') + + message = tool_parameters.get('message', '') + if not message: + return self.create_json_message( + { + 'error': 'Message is empty.' + } + ) + + auto_escape = tool_parameters.get('auto_escape', False) + + try: + resp = requests.post( + f'{self.runtime.credentials['ob11_http_url']}/send_private_msg', + json={ + 'user_id': send_user_id, + 'message': message, + 'auto_escape': auto_escape + }, + headers={ + 'Authorization': 'Bearer ' + self.runtime.credentials['access_token'] + } + ) + + if resp.status_code != 200: + return self.create_json_message( + { + 'error': f'Failed to send private message: {resp.text}' + } + ) + + return self.create_json_message( + { + 'response': resp.json() + } + ) + except Exception as e: + return self.create_json_message( + { + 'error': f'Failed to send private message: {e}' + } + ) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/onebot/tools/send_private_msg.yaml b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.yaml new file mode 100644 index 0000000000..8200ce4a83 --- /dev/null +++ b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.yaml @@ -0,0 +1,46 @@ +identity: + name: send_private_msg + author: RockChinQ + label: + en_US: Send Private Message + zh_Hans: 发送私聊消息 +description: + human: + en_US: Send a private message to a user + zh_Hans: 发送私聊消息给用户 + llm: A tool for sending a message segment to a user in private chat +parameters: + - name: user_id + type: number + required: true + label: + en_US: Target User ID + zh_Hans: 目标用户 ID + human_description: + en_US: The user ID of the target user + zh_Hans: 目标用户的用户 ID + llm_description: The user ID of the target user + form: llm + - name: message + type: string + required: true + label: + en_US: Message + zh_Hans: 消息 + human_description: + en_US: The message to send + zh_Hans: 要发送的消息。支持 CQ码(需要同时设置 auto_escape 为 true) + llm_description: The message to send + form: llm + - name: auto_escape + type: boolean + required: false + default: false + label: + en_US: Auto Escape + zh_Hans: 自动转义 + human_description: + en_US: If true, the message will be treated as a CQ code for parsing, otherwise it will be treated as plain text for direct sending. Since Dify currently does not support passing Object-format message chains, developers can send complex message components through CQ codes. + zh_Hans: 若为 true 则会把 message 视为 CQ 码解析,否则视为 纯文本 直接发送。由于 Dify 目前不支持传入 Object格式 的消息,故开发者可以通过 CQ 码来发送复杂消息组件。 + llm_description: If true, the message will be treated as a CQ code for parsing, otherwise it will be treated as plain text for direct sending. + form: form From e3d7c7c6f961a1b83feda718390e388a7ca2d8a5 Mon Sep 17 00:00:00 2001 From: Junyan Qin <1010553892@qq.com> Date: Fri, 23 Aug 2024 22:22:42 +0800 Subject: [PATCH 21/24] fix(onebot): use yarl to format url (#7589) --- .../tools/provider/builtin/onebot/tools/send_group_msg.py | 4 +++- .../tools/provider/builtin/onebot/tools/send_private_msg.py | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py index 802bb7b610..2a1a9f86de 100644 --- a/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py +++ b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py @@ -1,6 +1,7 @@ from typing import Any, Union import requests +from yarl import URL from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool @@ -29,9 +30,10 @@ class SendGroupMsg(BuiltinTool): auto_escape = tool_parameters.get('auto_escape', False) try: + url = URL(self.runtime.credentials['ob11_http_url']) / 'send_group_msg' resp = requests.post( - f'{self.runtime.credentials['ob11_http_url']}/send_group_msg', + url, json={ 'group_id': send_group_id, 'message': message, diff --git a/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py index a11c7a46c0..8ef4d72ab6 100644 --- a/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py +++ b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py @@ -1,6 +1,7 @@ from typing import Any, Union import requests +from yarl import URL from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool @@ -29,8 +30,10 @@ class SendPrivateMsg(BuiltinTool): auto_escape = tool_parameters.get('auto_escape', False) try: + url = URL(self.runtime.credentials['ob11_http_url']) / 'send_private_msg' + resp = requests.post( - f'{self.runtime.credentials['ob11_http_url']}/send_private_msg', + url, json={ 'user_id': send_user_id, 'message': message, From 3ace01cfb35b3e473a53993ddc35b1e39e2aee0b Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Fri, 23 Aug 2024 22:40:07 +0800 Subject: [PATCH 22/24] chore: cleanup and rearrange unclassified configs into feature config groups (#7586) --- api/configs/app_config.py | 37 +------- api/configs/deploy/__init__.py | 5 ++ api/configs/feature/__init__.py | 86 ++++++++++++++++++- api/core/workflow/nodes/code/code_node.py | 43 +++++----- .../nodes/http_request/http_executor.py | 23 ++--- api/poetry.lock | 9 +- api/pyproject.toml | 2 +- 7 files changed, 125 insertions(+), 80 deletions(-) diff --git a/api/configs/app_config.py b/api/configs/app_config.py index 494e256442..ff8c77de48 100644 --- a/api/configs/app_config.py +++ b/api/configs/app_config.py @@ -1,4 +1,3 @@ -from pydantic import Field, computed_field from pydantic_settings import SettingsConfigDict from configs.deploy import DeploymentConfig @@ -24,8 +23,6 @@ class DifyConfig( # **Before using, please contact business@dify.ai by email to inquire about licensing matters.** EnterpriseFeatureConfig, ): - DEBUG: bool = Field(default=False, description='whether to enable debug mode.') - model_config = SettingsConfigDict( # read from dotenv format config file env_file='.env', @@ -35,33 +32,7 @@ class DifyConfig( extra='ignore', ) - CODE_MAX_NUMBER: int = 9223372036854775807 - CODE_MIN_NUMBER: int = -9223372036854775808 - CODE_MAX_DEPTH: int = 5 - CODE_MAX_PRECISION: int = 20 - CODE_MAX_STRING_LENGTH: int = 80000 - CODE_MAX_STRING_ARRAY_LENGTH: int = 30 - CODE_MAX_OBJECT_ARRAY_LENGTH: int = 30 - CODE_MAX_NUMBER_ARRAY_LENGTH: int = 1000 - - HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = 300 - HTTP_REQUEST_MAX_READ_TIMEOUT: int = 600 - HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = 600 - HTTP_REQUEST_NODE_MAX_BINARY_SIZE: int = 1024 * 1024 * 10 - - @computed_field - def HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE(self) -> str: - return f'{self.HTTP_REQUEST_NODE_MAX_BINARY_SIZE / 1024 / 1024:.2f}MB' - - HTTP_REQUEST_NODE_MAX_TEXT_SIZE: int = 1024 * 1024 - - @computed_field - def HTTP_REQUEST_NODE_READABLE_MAX_TEXT_SIZE(self) -> str: - return f'{self.HTTP_REQUEST_NODE_MAX_TEXT_SIZE / 1024 / 1024:.2f}MB' - - SSRF_PROXY_HTTP_URL: str | None = None - SSRF_PROXY_HTTPS_URL: str | None = None - - MODERATION_BUFFER_SIZE: int = Field(default=300, description='The buffer size for moderation.') - - MAX_VARIABLE_SIZE: int = Field(default=5 * 1024, description='The maximum size of a variable. default is 5KB.') + # Before adding any config, + # please consider to arrange it in the proper config group of existed or added + # for better readability and maintainability. + # Thanks for your concentration and consideration. diff --git a/api/configs/deploy/__init__.py b/api/configs/deploy/__init__.py index 219b315784..c99e3d21d2 100644 --- a/api/configs/deploy/__init__.py +++ b/api/configs/deploy/__init__.py @@ -11,6 +11,11 @@ class DeploymentConfig(BaseSettings): default='langgenius/dify', ) + DEBUG: bool = Field( + description='whether to enable debug mode.', + default=False, + ) + TESTING: bool = Field( description='', default=False, diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 27de2c0461..7f36abf7a6 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1,6 +1,6 @@ from typing import Optional -from pydantic import AliasChoices, Field, NonNegativeInt, PositiveInt, computed_field +from pydantic import AliasChoices, Field, NegativeInt, NonNegativeInt, PositiveInt, computed_field from pydantic_settings import BaseSettings from configs.feature.hosted_service import HostedServiceConfig @@ -52,6 +52,46 @@ class CodeExecutionSandboxConfig(BaseSettings): default='dify-sandbox', ) + CODE_MAX_NUMBER: PositiveInt = Field( + description='max depth for code execution', + default=9223372036854775807, + ) + + CODE_MIN_NUMBER: NegativeInt = Field( + description='', + default=-9223372036854775807, + ) + + CODE_MAX_DEPTH: PositiveInt = Field( + description='max depth for code execution', + default=5, + ) + + CODE_MAX_PRECISION: PositiveInt = Field( + description='max precision digits for float type in code execution', + default=20, + ) + + CODE_MAX_STRING_LENGTH: PositiveInt = Field( + description='max string length for code execution', + default=80000, + ) + + CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field( + description='', + default=30, + ) + + CODE_MAX_OBJECT_ARRAY_LENGTH: PositiveInt = Field( + description='', + default=30, + ) + + CODE_MAX_NUMBER_ARRAY_LENGTH: PositiveInt = Field( + description='', + default=1000, + ) + class EndpointConfig(BaseSettings): """ @@ -157,6 +197,41 @@ class HttpConfig(BaseSettings): def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]: return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(',') + HTTP_REQUEST_MAX_CONNECT_TIMEOUT: NonNegativeInt = Field( + description='', + default=300, + ) + + HTTP_REQUEST_MAX_READ_TIMEOUT: NonNegativeInt = Field( + description='', + default=600, + ) + + HTTP_REQUEST_MAX_WRITE_TIMEOUT: NonNegativeInt = Field( + description='', + default=600, + ) + + HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field( + description='', + default=10 * 1024 * 1024, + ) + + HTTP_REQUEST_NODE_MAX_TEXT_SIZE: PositiveInt = Field( + description='', + default=1 * 1024 * 1024, + ) + + SSRF_PROXY_HTTP_URL: Optional[str] = Field( + description='HTTP URL for SSRF proxy', + default=None, + ) + + SSRF_PROXY_HTTPS_URL: Optional[str] = Field( + description='HTTPS URL for SSRF proxy', + default=None, + ) + class InnerAPIConfig(BaseSettings): """ @@ -255,6 +330,11 @@ class WorkflowConfig(BaseSettings): default=5, ) + MAX_VARIABLE_SIZE: PositiveInt = Field( + description='The maximum size in bytes of a variable. default to 5KB.', + default=5 * 1024, + ) + class OAuthConfig(BaseSettings): """ @@ -291,8 +371,7 @@ class ModerationConfig(BaseSettings): Moderation in app configs. """ - # todo: to be clarified in usage and unit - OUTPUT_MODERATION_BUFFER_SIZE: PositiveInt = Field( + MODERATION_BUFFER_SIZE: PositiveInt = Field( description='buffer size for moderation', default=300, ) @@ -444,7 +523,6 @@ class CeleryBeatConfig(BaseSettings): class PositionConfig(BaseSettings): - POSITION_PROVIDER_PINS: str = Field( description='The heads of model providers', default='', diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 71adf18a51..17554d3db4 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -11,15 +11,6 @@ from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.code.entities import CodeNodeData from models.workflow import WorkflowNodeExecutionStatus -MAX_NUMBER = dify_config.CODE_MAX_NUMBER -MIN_NUMBER = dify_config.CODE_MIN_NUMBER -MAX_PRECISION = dify_config.CODE_MAX_PRECISION -MAX_DEPTH = dify_config.CODE_MAX_DEPTH -MAX_STRING_LENGTH = dify_config.CODE_MAX_STRING_LENGTH -MAX_STRING_ARRAY_LENGTH = dify_config.CODE_MAX_STRING_ARRAY_LENGTH -MAX_OBJECT_ARRAY_LENGTH = dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH -MAX_NUMBER_ARRAY_LENGTH = dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH - class CodeNode(BaseNode): _node_data_cls = CodeNodeData @@ -97,8 +88,9 @@ class CodeNode(BaseNode): else: raise ValueError(f"Output variable `{variable}` must be a string") - if len(value) > MAX_STRING_LENGTH: - raise ValueError(f'The length of output variable `{variable}` must be less than {MAX_STRING_LENGTH} characters') + if len(value) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH: + raise ValueError(f'The length of output variable `{variable}` must be' + f' less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} characters') return value.replace('\x00', '') @@ -115,13 +107,15 @@ class CodeNode(BaseNode): else: raise ValueError(f"Output variable `{variable}` must be a number") - if value > MAX_NUMBER or value < MIN_NUMBER: - raise ValueError(f'Output variable `{variable}` is out of range, it must be between {MIN_NUMBER} and {MAX_NUMBER}.') + if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER: + raise ValueError(f'Output variable `{variable}` is out of range,' + f' it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}.') if isinstance(value, float): # raise error if precision is too high - if len(str(value).split('.')[1]) > MAX_PRECISION: - raise ValueError(f'Output variable `{variable}` has too high precision, it must be less than {MAX_PRECISION} digits.') + if len(str(value).split('.')[1]) > dify_config.CODE_MAX_PRECISION: + raise ValueError(f'Output variable `{variable}` has too high precision,' + f' it must be less than {dify_config.CODE_MAX_PRECISION} digits.') return value @@ -134,8 +128,8 @@ class CodeNode(BaseNode): :param output_schema: output schema :return: """ - if depth > MAX_DEPTH: - raise ValueError("Depth limit reached, object too deep.") + if depth > dify_config.CODE_MAX_DEPTH: + raise ValueError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.") transformed_result = {} if output_schema is None: @@ -235,9 +229,10 @@ class CodeNode(BaseNode): f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' ) else: - if len(result[output_name]) > MAX_NUMBER_ARRAY_LENGTH: + if len(result[output_name]) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH: raise ValueError( - f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_NUMBER_ARRAY_LENGTH} elements.' + f'The length of output variable `{prefix}{dot}{output_name}` must be' + f' less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements.' ) transformed_result[output_name] = [ @@ -257,9 +252,10 @@ class CodeNode(BaseNode): f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' ) else: - if len(result[output_name]) > MAX_STRING_ARRAY_LENGTH: + if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH: raise ValueError( - f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_STRING_ARRAY_LENGTH} elements.' + f'The length of output variable `{prefix}{dot}{output_name}` must be' + f' less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements.' ) transformed_result[output_name] = [ @@ -279,9 +275,10 @@ class CodeNode(BaseNode): f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' ) else: - if len(result[output_name]) > MAX_OBJECT_ARRAY_LENGTH: + if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH: raise ValueError( - f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_OBJECT_ARRAY_LENGTH} elements.' + f'The length of output variable `{prefix}{dot}{output_name}` must be' + f' less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements.' ) for i, value in enumerate(result[output_name]): diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py index db18bd00b2..d16bff58bd 100644 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -18,11 +18,6 @@ from core.workflow.nodes.http_request.entities import ( ) from core.workflow.utils.variable_template_parser import VariableTemplateParser -MAX_BINARY_SIZE = dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE -READABLE_MAX_BINARY_SIZE = dify_config.HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE -MAX_TEXT_SIZE = dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE -READABLE_MAX_TEXT_SIZE = dify_config.HTTP_REQUEST_NODE_READABLE_MAX_TEXT_SIZE - class HttpExecutorResponse: headers: dict[str, str] @@ -237,16 +232,14 @@ class HttpExecutor: else: raise ValueError(f'Invalid response type {type(response)}') - if executor_response.is_file: - if executor_response.size > MAX_BINARY_SIZE: - raise ValueError( - f'File size is too large, max size is {READABLE_MAX_BINARY_SIZE}, but current size is {executor_response.readable_size}.' - ) - else: - if executor_response.size > MAX_TEXT_SIZE: - raise ValueError( - f'Text size is too large, max size is {READABLE_MAX_TEXT_SIZE}, but current size is {executor_response.readable_size}.' - ) + threshold_size = dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE if executor_response.is_file \ + else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE + if executor_response.size > threshold_size: + raise ValueError( + f'{"File" if executor_response.is_file else "Text"} size is too large,' + f' max size is {threshold_size / 1024 / 1024:.2f} MB,' + f' but current size is {executor_response.readable_size}.' + ) return executor_response diff --git a/api/poetry.lock b/api/poetry.lock index a68eadefcf..0a8919a30a 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -6372,13 +6372,13 @@ semver = ["semver (>=3.0.2)"] [[package]] name = "pydantic-settings" -version = "2.3.4" +version = "2.4.0" description = "Settings management using Pydantic" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic_settings-2.3.4-py3-none-any.whl", hash = "sha256:11ad8bacb68a045f00e4f862c7a718c8a9ec766aa8fd4c32e39a0594b207b53a"}, - {file = "pydantic_settings-2.3.4.tar.gz", hash = "sha256:c5802e3d62b78e82522319bbc9b8f8ffb28ad1c988a99311d04f2a6051fca0a7"}, + {file = "pydantic_settings-2.4.0-py3-none-any.whl", hash = "sha256:bb6849dc067f1687574c12a639e231f3a6feeed0a12d710c1382045c5db1c315"}, + {file = "pydantic_settings-2.4.0.tar.gz", hash = "sha256:ed81c3a0f46392b4d7c0a565c05884e6e54b3456e6f0fe4d8814981172dc9a88"}, ] [package.dependencies] @@ -6386,6 +6386,7 @@ pydantic = ">=2.7.0" python-dotenv = ">=0.21.0" [package.extras] +azure-key-vault = ["azure-identity (>=1.16.0)", "azure-keyvault-secrets (>=4.8.0)"] toml = ["tomli (>=2.0.1)"] yaml = ["pyyaml (>=6.0.1)"] @@ -9633,4 +9634,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "69c20af8ecacced3cca092662223a1511acaf65cb2616a5a1e38b498223463e0" +content-hash = "d7336115709114c2a4ff09b392f717e9c3547ae82b6a111d0c885c7a44269f02" diff --git a/api/pyproject.toml b/api/pyproject.toml index 8e050c4101..47b638573b 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -162,7 +162,7 @@ pandas = { version = "~2.2.2", extras = ["performance", "excel"] } psycopg2-binary = "~2.9.6" pycryptodome = "3.19.1" pydantic = "~2.8.2" -pydantic-settings = "~2.3.4" +pydantic-settings = "~2.4.0" pydantic_extra_types = "~2.9.0" pyjwt = "~2.8.0" pypdfium2 = "~4.17.0" From 2da63654e545f56695778c191b8b68f94775fd36 Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Fri, 23 Aug 2024 23:46:01 +0800 Subject: [PATCH 23/24] chore(api/configs): apply ruff reformat (#7590) --- api/configs/__init__.py | 2 +- api/configs/app_config.py | 6 +- api/configs/deploy/__init__.py | 17 +- api/configs/enterprise/__init__.py | 7 +- api/configs/extra/notion_config.py | 11 +- api/configs/extra/sentry_config.py | 7 +- api/configs/feature/__init__.py | 258 +++++++++--------- .../feature/hosted_service/__init__.py | 101 ++++--- api/configs/middleware/__init__.py | 97 +++---- api/configs/middleware/cache/redis_config.py | 15 +- .../storage/aliyun_oss_storage_config.py | 12 +- .../storage/amazon_s3_storage_config.py | 16 +- .../storage/azure_blob_storage_config.py | 8 +- .../storage/google_cloud_storage_config.py | 4 +- .../middleware/storage/oci_storage_config.py | 11 +- .../storage/tencent_cos_storage_config.py | 10 +- .../middleware/vdb/analyticdb_config.py | 39 ++- api/configs/middleware/vdb/chroma_config.py | 12 +- api/configs/middleware/vdb/milvus_config.py | 14 +- api/configs/middleware/vdb/myscale_config.py | 23 +- .../middleware/vdb/opensearch_config.py | 10 +- api/configs/middleware/vdb/oracle_config.py | 10 +- api/configs/middleware/vdb/pgvector_config.py | 10 +- .../middleware/vdb/pgvectors_config.py | 10 +- api/configs/middleware/vdb/qdrant_config.py | 10 +- api/configs/middleware/vdb/relyt_config.py | 12 +- .../middleware/vdb/tencent_vector_config.py | 16 +- .../middleware/vdb/tidb_vector_config.py | 10 +- api/configs/middleware/vdb/weaviate_config.py | 8 +- api/configs/packaging/__init__.py | 6 +- api/pyproject.toml | 1 - 31 files changed, 388 insertions(+), 385 deletions(-) diff --git a/api/configs/__init__.py b/api/configs/__init__.py index c0e28c34e1..3a172601c9 100644 --- a/api/configs/__init__.py +++ b/api/configs/__init__.py @@ -1,3 +1,3 @@ from .app_config import DifyConfig -dify_config = DifyConfig() \ No newline at end of file +dify_config = DifyConfig() diff --git a/api/configs/app_config.py b/api/configs/app_config.py index ff8c77de48..61de73c868 100644 --- a/api/configs/app_config.py +++ b/api/configs/app_config.py @@ -25,11 +25,11 @@ class DifyConfig( ): model_config = SettingsConfigDict( # read from dotenv format config file - env_file='.env', - env_file_encoding='utf-8', + env_file=".env", + env_file_encoding="utf-8", frozen=True, # ignore extra attributes - extra='ignore', + extra="ignore", ) # Before adding any config, diff --git a/api/configs/deploy/__init__.py b/api/configs/deploy/__init__.py index c99e3d21d2..10271483c4 100644 --- a/api/configs/deploy/__init__.py +++ b/api/configs/deploy/__init__.py @@ -6,27 +6,28 @@ class DeploymentConfig(BaseSettings): """ Deployment configs """ + APPLICATION_NAME: str = Field( - description='application name', - default='langgenius/dify', + description="application name", + default="langgenius/dify", ) DEBUG: bool = Field( - description='whether to enable debug mode.', + description="whether to enable debug mode.", default=False, ) TESTING: bool = Field( - description='', + description="", default=False, ) EDITION: str = Field( - description='deployment edition', - default='SELF_HOSTED', + description="deployment edition", + default="SELF_HOSTED", ) DEPLOY_ENV: str = Field( - description='deployment environment, default to PRODUCTION.', - default='PRODUCTION', + description="deployment environment, default to PRODUCTION.", + default="PRODUCTION", ) diff --git a/api/configs/enterprise/__init__.py b/api/configs/enterprise/__init__.py index b5d884e10e..c661593a44 100644 --- a/api/configs/enterprise/__init__.py +++ b/api/configs/enterprise/__init__.py @@ -7,13 +7,14 @@ class EnterpriseFeatureConfig(BaseSettings): Enterprise feature configs. **Before using, please contact business@dify.ai by email to inquire about licensing matters.** """ + ENTERPRISE_ENABLED: bool = Field( - description='whether to enable enterprise features.' - 'Before using, please contact business@dify.ai by email to inquire about licensing matters.', + description="whether to enable enterprise features." + "Before using, please contact business@dify.ai by email to inquire about licensing matters.", default=False, ) CAN_REPLACE_LOGO: bool = Field( - description='whether to allow replacing enterprise logo.', + description="whether to allow replacing enterprise logo.", default=False, ) diff --git a/api/configs/extra/notion_config.py b/api/configs/extra/notion_config.py index b77e8adaae..bd1268fa45 100644 --- a/api/configs/extra/notion_config.py +++ b/api/configs/extra/notion_config.py @@ -8,27 +8,28 @@ class NotionConfig(BaseSettings): """ Notion integration configs """ + NOTION_CLIENT_ID: Optional[str] = Field( - description='Notion client ID', + description="Notion client ID", default=None, ) NOTION_CLIENT_SECRET: Optional[str] = Field( - description='Notion client secret key', + description="Notion client secret key", default=None, ) NOTION_INTEGRATION_TYPE: Optional[str] = Field( - description='Notion integration type, default to None, available values: internal.', + description="Notion integration type, default to None, available values: internal.", default=None, ) NOTION_INTERNAL_SECRET: Optional[str] = Field( - description='Notion internal secret key', + description="Notion internal secret key", default=None, ) NOTION_INTEGRATION_TOKEN: Optional[str] = Field( - description='Notion integration token', + description="Notion integration token", default=None, ) diff --git a/api/configs/extra/sentry_config.py b/api/configs/extra/sentry_config.py index e6517f730a..ea9ea60ffb 100644 --- a/api/configs/extra/sentry_config.py +++ b/api/configs/extra/sentry_config.py @@ -8,17 +8,18 @@ class SentryConfig(BaseSettings): """ Sentry configs """ + SENTRY_DSN: Optional[str] = Field( - description='Sentry DSN', + description="Sentry DSN", default=None, ) SENTRY_TRACES_SAMPLE_RATE: NonNegativeFloat = Field( - description='Sentry trace sample rate', + description="Sentry trace sample rate", default=1.0, ) SENTRY_PROFILES_SAMPLE_RATE: NonNegativeFloat = Field( - description='Sentry profiles sample rate', + description="Sentry profiles sample rate", default=1.0, ) diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 7f36abf7a6..46ae7a0bc8 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -10,16 +10,17 @@ class SecurityConfig(BaseSettings): """ Secret Key configs """ + SECRET_KEY: Optional[str] = Field( - description='Your App secret key will be used for securely signing the session cookie' - 'Make sure you are changing this key for your deployment with a strong key.' - 'You can generate a strong key using `openssl rand -base64 42`.' - 'Alternatively you can set it with `SECRET_KEY` environment variable.', + description="Your App secret key will be used for securely signing the session cookie" + "Make sure you are changing this key for your deployment with a strong key." + "You can generate a strong key using `openssl rand -base64 42`." + "Alternatively you can set it with `SECRET_KEY` environment variable.", default=None, ) RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field( - description='Expiry time in hours for reset token', + description="Expiry time in hours for reset token", default=24, ) @@ -28,12 +29,13 @@ class AppExecutionConfig(BaseSettings): """ App Execution configs """ + APP_MAX_EXECUTION_TIME: PositiveInt = Field( - description='execution timeout in seconds for app execution', + description="execution timeout in seconds for app execution", default=1200, ) APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field( - description='max active request per app, 0 means unlimited', + description="max active request per app, 0 means unlimited", default=0, ) @@ -42,53 +44,54 @@ class CodeExecutionSandboxConfig(BaseSettings): """ Code Execution Sandbox configs """ + CODE_EXECUTION_ENDPOINT: str = Field( - description='endpoint URL of code execution servcie', - default='http://sandbox:8194', + description="endpoint URL of code execution servcie", + default="http://sandbox:8194", ) CODE_EXECUTION_API_KEY: str = Field( - description='API key for code execution service', - default='dify-sandbox', + description="API key for code execution service", + default="dify-sandbox", ) CODE_MAX_NUMBER: PositiveInt = Field( - description='max depth for code execution', + description="max depth for code execution", default=9223372036854775807, ) CODE_MIN_NUMBER: NegativeInt = Field( - description='', + description="", default=-9223372036854775807, ) CODE_MAX_DEPTH: PositiveInt = Field( - description='max depth for code execution', + description="max depth for code execution", default=5, ) CODE_MAX_PRECISION: PositiveInt = Field( - description='max precision digits for float type in code execution', + description="max precision digits for float type in code execution", default=20, ) CODE_MAX_STRING_LENGTH: PositiveInt = Field( - description='max string length for code execution', + description="max string length for code execution", default=80000, ) CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field( - description='', + description="", default=30, ) CODE_MAX_OBJECT_ARRAY_LENGTH: PositiveInt = Field( - description='', + description="", default=30, ) CODE_MAX_NUMBER_ARRAY_LENGTH: PositiveInt = Field( - description='', + description="", default=1000, ) @@ -97,28 +100,27 @@ class EndpointConfig(BaseSettings): """ Module URL configs """ + CONSOLE_API_URL: str = Field( - description='The backend URL prefix of the console API.' - 'used to concatenate the login authorization callback or notion integration callback.', - default='', + description="The backend URL prefix of the console API." + "used to concatenate the login authorization callback or notion integration callback.", + default="", ) CONSOLE_WEB_URL: str = Field( - description='The front-end URL prefix of the console web.' - 'used to concatenate some front-end addresses and for CORS configuration use.', - default='', + description="The front-end URL prefix of the console web." + "used to concatenate some front-end addresses and for CORS configuration use.", + default="", ) SERVICE_API_URL: str = Field( - description='Service API Url prefix.' - 'used to display Service API Base Url to the front-end.', - default='', + description="Service API Url prefix." "used to display Service API Base Url to the front-end.", + default="", ) APP_WEB_URL: str = Field( - description='WebApp Url prefix.' - 'used to display WebAPP API Base Url to the front-end.', - default='', + description="WebApp Url prefix." "used to display WebAPP API Base Url to the front-end.", + default="", ) @@ -126,17 +128,18 @@ class FileAccessConfig(BaseSettings): """ File Access configs """ + FILES_URL: str = Field( - description='File preview or download Url prefix.' - ' used to display File preview or download Url to the front-end or as Multi-model inputs;' - 'Url is signed and has expiration time.', - validation_alias=AliasChoices('FILES_URL', 'CONSOLE_API_URL'), + description="File preview or download Url prefix." + " used to display File preview or download Url to the front-end or as Multi-model inputs;" + "Url is signed and has expiration time.", + validation_alias=AliasChoices("FILES_URL", "CONSOLE_API_URL"), alias_priority=1, - default='', + default="", ) FILES_ACCESS_TIMEOUT: int = Field( - description='timeout in seconds for file accessing', + description="timeout in seconds for file accessing", default=300, ) @@ -145,23 +148,24 @@ class FileUploadConfig(BaseSettings): """ File Uploading configs """ + UPLOAD_FILE_SIZE_LIMIT: NonNegativeInt = Field( - description='size limit in Megabytes for uploading files', + description="size limit in Megabytes for uploading files", default=15, ) UPLOAD_FILE_BATCH_LIMIT: NonNegativeInt = Field( - description='batch size limit for uploading files', + description="batch size limit for uploading files", default=5, ) UPLOAD_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field( - description='image file size limit in Megabytes for uploading files', + description="image file size limit in Megabytes for uploading files", default=10, ) BATCH_UPLOAD_LIMIT: NonNegativeInt = Field( - description='', # todo: to be clarified + description="", # todo: to be clarified default=20, ) @@ -170,65 +174,66 @@ class HttpConfig(BaseSettings): """ HTTP configs """ + API_COMPRESSION_ENABLED: bool = Field( - description='whether to enable HTTP response compression of gzip', + description="whether to enable HTTP response compression of gzip", default=False, ) inner_CONSOLE_CORS_ALLOW_ORIGINS: str = Field( - description='', - validation_alias=AliasChoices('CONSOLE_CORS_ALLOW_ORIGINS', 'CONSOLE_WEB_URL'), - default='', + description="", + validation_alias=AliasChoices("CONSOLE_CORS_ALLOW_ORIGINS", "CONSOLE_WEB_URL"), + default="", ) @computed_field @property def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]: - return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(',') + return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(",") inner_WEB_API_CORS_ALLOW_ORIGINS: str = Field( - description='', - validation_alias=AliasChoices('WEB_API_CORS_ALLOW_ORIGINS'), - default='*', + description="", + validation_alias=AliasChoices("WEB_API_CORS_ALLOW_ORIGINS"), + default="*", ) @computed_field @property def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]: - return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(',') + return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",") HTTP_REQUEST_MAX_CONNECT_TIMEOUT: NonNegativeInt = Field( - description='', + description="", default=300, ) HTTP_REQUEST_MAX_READ_TIMEOUT: NonNegativeInt = Field( - description='', + description="", default=600, ) HTTP_REQUEST_MAX_WRITE_TIMEOUT: NonNegativeInt = Field( - description='', + description="", default=600, ) HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field( - description='', + description="", default=10 * 1024 * 1024, ) HTTP_REQUEST_NODE_MAX_TEXT_SIZE: PositiveInt = Field( - description='', + description="", default=1 * 1024 * 1024, ) SSRF_PROXY_HTTP_URL: Optional[str] = Field( - description='HTTP URL for SSRF proxy', + description="HTTP URL for SSRF proxy", default=None, ) SSRF_PROXY_HTTPS_URL: Optional[str] = Field( - description='HTTPS URL for SSRF proxy', + description="HTTPS URL for SSRF proxy", default=None, ) @@ -237,13 +242,14 @@ class InnerAPIConfig(BaseSettings): """ Inner API configs """ + INNER_API: bool = Field( - description='whether to enable the inner API', + description="whether to enable the inner API", default=False, ) INNER_API_KEY: Optional[str] = Field( - description='The inner API key is used to authenticate the inner API', + description="The inner API key is used to authenticate the inner API", default=None, ) @@ -254,28 +260,27 @@ class LoggingConfig(BaseSettings): """ LOG_LEVEL: str = Field( - description='Log output level, default to INFO.' - 'It is recommended to set it to ERROR for production.', - default='INFO', + description="Log output level, default to INFO." "It is recommended to set it to ERROR for production.", + default="INFO", ) LOG_FILE: Optional[str] = Field( - description='logging output file path', + description="logging output file path", default=None, ) LOG_FORMAT: str = Field( - description='log format', - default='%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s', + description="log format", + default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s", ) LOG_DATEFORMAT: Optional[str] = Field( - description='log date format', + description="log date format", default=None, ) LOG_TZ: Optional[str] = Field( - description='specify log timezone, eg: America/New_York', + description="specify log timezone, eg: America/New_York", default=None, ) @@ -284,8 +289,9 @@ class ModelLoadBalanceConfig(BaseSettings): """ Model load balance configs """ + MODEL_LB_ENABLED: bool = Field( - description='whether to enable model load balancing', + description="whether to enable model load balancing", default=False, ) @@ -294,8 +300,9 @@ class BillingConfig(BaseSettings): """ Platform Billing Configurations """ + BILLING_ENABLED: bool = Field( - description='whether to enable billing', + description="whether to enable billing", default=False, ) @@ -304,9 +311,10 @@ class UpdateConfig(BaseSettings): """ Update configs """ + CHECK_UPDATE_URL: str = Field( - description='url for checking updates', - default='https://updates.dify.ai', + description="url for checking updates", + default="https://updates.dify.ai", ) @@ -316,22 +324,22 @@ class WorkflowConfig(BaseSettings): """ WORKFLOW_MAX_EXECUTION_STEPS: PositiveInt = Field( - description='max execution steps in single workflow execution', + description="max execution steps in single workflow execution", default=500, ) WORKFLOW_MAX_EXECUTION_TIME: PositiveInt = Field( - description='max execution time in seconds in single workflow execution', + description="max execution time in seconds in single workflow execution", default=1200, ) WORKFLOW_CALL_MAX_DEPTH: PositiveInt = Field( - description='max depth of calling in single workflow execution', + description="max depth of calling in single workflow execution", default=5, ) MAX_VARIABLE_SIZE: PositiveInt = Field( - description='The maximum size in bytes of a variable. default to 5KB.', + description="The maximum size in bytes of a variable. default to 5KB.", default=5 * 1024, ) @@ -340,28 +348,29 @@ class OAuthConfig(BaseSettings): """ oauth configs """ + OAUTH_REDIRECT_PATH: str = Field( - description='redirect path for OAuth', - default='/console/api/oauth/authorize', + description="redirect path for OAuth", + default="/console/api/oauth/authorize", ) GITHUB_CLIENT_ID: Optional[str] = Field( - description='GitHub client id for OAuth', + description="GitHub client id for OAuth", default=None, ) GITHUB_CLIENT_SECRET: Optional[str] = Field( - description='GitHub client secret key for OAuth', + description="GitHub client secret key for OAuth", default=None, ) GOOGLE_CLIENT_ID: Optional[str] = Field( - description='Google client id for OAuth', + description="Google client id for OAuth", default=None, ) GOOGLE_CLIENT_SECRET: Optional[str] = Field( - description='Google client secret key for OAuth', + description="Google client secret key for OAuth", default=None, ) @@ -372,7 +381,7 @@ class ModerationConfig(BaseSettings): """ MODERATION_BUFFER_SIZE: PositiveInt = Field( - description='buffer size for moderation', + description="buffer size for moderation", default=300, ) @@ -383,7 +392,7 @@ class ToolConfig(BaseSettings): """ TOOL_ICON_CACHE_MAX_AGE: PositiveInt = Field( - description='max age in seconds for tool icon caching', + description="max age in seconds for tool icon caching", default=3600, ) @@ -394,52 +403,52 @@ class MailConfig(BaseSettings): """ MAIL_TYPE: Optional[str] = Field( - description='Mail provider type name, default to None, availabile values are `smtp` and `resend`.', + description="Mail provider type name, default to None, availabile values are `smtp` and `resend`.", default=None, ) MAIL_DEFAULT_SEND_FROM: Optional[str] = Field( - description='default email address for sending from ', + description="default email address for sending from ", default=None, ) RESEND_API_KEY: Optional[str] = Field( - description='API key for Resend', + description="API key for Resend", default=None, ) RESEND_API_URL: Optional[str] = Field( - description='API URL for Resend', + description="API URL for Resend", default=None, ) SMTP_SERVER: Optional[str] = Field( - description='smtp server host', + description="smtp server host", default=None, ) SMTP_PORT: Optional[int] = Field( - description='smtp server port', + description="smtp server port", default=465, ) SMTP_USERNAME: Optional[str] = Field( - description='smtp server username', + description="smtp server username", default=None, ) SMTP_PASSWORD: Optional[str] = Field( - description='smtp server password', + description="smtp server password", default=None, ) SMTP_USE_TLS: bool = Field( - description='whether to use TLS connection to smtp server', + description="whether to use TLS connection to smtp server", default=False, ) SMTP_OPPORTUNISTIC_TLS: bool = Field( - description='whether to use opportunistic TLS connection to smtp server', + description="whether to use opportunistic TLS connection to smtp server", default=False, ) @@ -450,22 +459,22 @@ class RagEtlConfig(BaseSettings): """ ETL_TYPE: str = Field( - description='RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ', - default='dify', + description="RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ", + default="dify", ) KEYWORD_DATA_SOURCE_TYPE: str = Field( - description='source type for keyword data, default to `database`, available values are `database` .', - default='database', + description="source type for keyword data, default to `database`, available values are `database` .", + default="database", ) UNSTRUCTURED_API_URL: Optional[str] = Field( - description='API URL for Unstructured', + description="API URL for Unstructured", default=None, ) UNSTRUCTURED_API_KEY: Optional[str] = Field( - description='API key for Unstructured', + description="API key for Unstructured", default=None, ) @@ -476,12 +485,12 @@ class DataSetConfig(BaseSettings): """ CLEAN_DAY_SETTING: PositiveInt = Field( - description='interval in days for cleaning up dataset', + description="interval in days for cleaning up dataset", default=30, ) DATASET_OPERATOR_ENABLED: bool = Field( - description='whether to enable dataset operator', + description="whether to enable dataset operator", default=False, ) @@ -492,7 +501,7 @@ class WorkspaceConfig(BaseSettings): """ INVITE_EXPIRY_HOURS: PositiveInt = Field( - description='workspaces invitation expiration in hours', + description="workspaces invitation expiration in hours", default=72, ) @@ -503,79 +512,79 @@ class IndexingConfig(BaseSettings): """ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: PositiveInt = Field( - description='max segmentation token length for indexing', + description="max segmentation token length for indexing", default=1000, ) class ImageFormatConfig(BaseSettings): MULTIMODAL_SEND_IMAGE_FORMAT: str = Field( - description='multi model send image format, support base64, url, default is base64', - default='base64', + description="multi model send image format, support base64, url, default is base64", + default="base64", ) class CeleryBeatConfig(BaseSettings): CELERY_BEAT_SCHEDULER_TIME: int = Field( - description='the time of the celery scheduler, default to 1 day', + description="the time of the celery scheduler, default to 1 day", default=1, ) class PositionConfig(BaseSettings): POSITION_PROVIDER_PINS: str = Field( - description='The heads of model providers', - default='', + description="The heads of model providers", + default="", ) POSITION_PROVIDER_INCLUDES: str = Field( - description='The included model providers', - default='', + description="The included model providers", + default="", ) POSITION_PROVIDER_EXCLUDES: str = Field( - description='The excluded model providers', - default='', + description="The excluded model providers", + default="", ) POSITION_TOOL_PINS: str = Field( - description='The heads of tools', - default='', + description="The heads of tools", + default="", ) POSITION_TOOL_INCLUDES: str = Field( - description='The included tools', - default='', + description="The included tools", + default="", ) POSITION_TOOL_EXCLUDES: str = Field( - description='The excluded tools', - default='', + description="The excluded tools", + default="", ) @computed_field def POSITION_PROVIDER_PINS_LIST(self) -> list[str]: - return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(',') if item.strip() != ''] + return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(",") if item.strip() != ""] @computed_field def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]: - return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(',') if item.strip() != ''} + return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(",") if item.strip() != ""} @computed_field def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]: - return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(',') if item.strip() != ''} + return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(",") if item.strip() != ""} @computed_field def POSITION_TOOL_PINS_LIST(self) -> list[str]: - return [item.strip() for item in self.POSITION_TOOL_PINS.split(',') if item.strip() != ''] + return [item.strip() for item in self.POSITION_TOOL_PINS.split(",") if item.strip() != ""] @computed_field def POSITION_TOOL_INCLUDES_SET(self) -> set[str]: - return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(',') if item.strip() != ''} + return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(",") if item.strip() != ""} @computed_field def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]: - return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(',') if item.strip() != ''} + return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""} class FeatureConfig( @@ -603,7 +612,6 @@ class FeatureConfig( WorkflowConfig, WorkspaceConfig, PositionConfig, - # hosted services config HostedServiceConfig, CeleryBeatConfig, diff --git a/api/configs/feature/hosted_service/__init__.py b/api/configs/feature/hosted_service/__init__.py index 88fe188587..f269d0ab9c 100644 --- a/api/configs/feature/hosted_service/__init__.py +++ b/api/configs/feature/hosted_service/__init__.py @@ -10,62 +10,62 @@ class HostedOpenAiConfig(BaseSettings): """ HOSTED_OPENAI_API_KEY: Optional[str] = Field( - description='', + description="", default=None, ) HOSTED_OPENAI_API_BASE: Optional[str] = Field( - description='', + description="", default=None, ) HOSTED_OPENAI_API_ORGANIZATION: Optional[str] = Field( - description='', + description="", default=None, ) HOSTED_OPENAI_TRIAL_ENABLED: bool = Field( - description='', + description="", default=False, ) HOSTED_OPENAI_TRIAL_MODELS: str = Field( - description='', - default='gpt-3.5-turbo,' - 'gpt-3.5-turbo-1106,' - 'gpt-3.5-turbo-instruct,' - 'gpt-3.5-turbo-16k,' - 'gpt-3.5-turbo-16k-0613,' - 'gpt-3.5-turbo-0613,' - 'gpt-3.5-turbo-0125,' - 'text-davinci-003', + description="", + default="gpt-3.5-turbo," + "gpt-3.5-turbo-1106," + "gpt-3.5-turbo-instruct," + "gpt-3.5-turbo-16k," + "gpt-3.5-turbo-16k-0613," + "gpt-3.5-turbo-0613," + "gpt-3.5-turbo-0125," + "text-davinci-003", ) HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field( - description='', + description="", default=200, ) HOSTED_OPENAI_PAID_ENABLED: bool = Field( - description='', + description="", default=False, ) HOSTED_OPENAI_PAID_MODELS: str = Field( - description='', - default='gpt-4,' - 'gpt-4-turbo-preview,' - 'gpt-4-turbo-2024-04-09,' - 'gpt-4-1106-preview,' - 'gpt-4-0125-preview,' - 'gpt-3.5-turbo,' - 'gpt-3.5-turbo-16k,' - 'gpt-3.5-turbo-16k-0613,' - 'gpt-3.5-turbo-1106,' - 'gpt-3.5-turbo-0613,' - 'gpt-3.5-turbo-0125,' - 'gpt-3.5-turbo-instruct,' - 'text-davinci-003', + description="", + default="gpt-4," + "gpt-4-turbo-preview," + "gpt-4-turbo-2024-04-09," + "gpt-4-1106-preview," + "gpt-4-0125-preview," + "gpt-3.5-turbo," + "gpt-3.5-turbo-16k," + "gpt-3.5-turbo-16k-0613," + "gpt-3.5-turbo-1106," + "gpt-3.5-turbo-0613," + "gpt-3.5-turbo-0125," + "gpt-3.5-turbo-instruct," + "text-davinci-003", ) @@ -75,22 +75,22 @@ class HostedAzureOpenAiConfig(BaseSettings): """ HOSTED_AZURE_OPENAI_ENABLED: bool = Field( - description='', + description="", default=False, ) HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field( - description='', + description="", default=None, ) HOSTED_AZURE_OPENAI_API_BASE: Optional[str] = Field( - description='', + description="", default=None, ) HOSTED_AZURE_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field( - description='', + description="", default=200, ) @@ -101,27 +101,27 @@ class HostedAnthropicConfig(BaseSettings): """ HOSTED_ANTHROPIC_API_BASE: Optional[str] = Field( - description='', + description="", default=None, ) HOSTED_ANTHROPIC_API_KEY: Optional[str] = Field( - description='', + description="", default=None, ) HOSTED_ANTHROPIC_TRIAL_ENABLED: bool = Field( - description='', + description="", default=False, ) HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field( - description='', + description="", default=600000, ) HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field( - description='', + description="", default=False, ) @@ -132,7 +132,7 @@ class HostedMinmaxConfig(BaseSettings): """ HOSTED_MINIMAX_ENABLED: bool = Field( - description='', + description="", default=False, ) @@ -143,7 +143,7 @@ class HostedSparkConfig(BaseSettings): """ HOSTED_SPARK_ENABLED: bool = Field( - description='', + description="", default=False, ) @@ -154,7 +154,7 @@ class HostedZhipuAIConfig(BaseSettings): """ HOSTED_ZHIPUAI_ENABLED: bool = Field( - description='', + description="", default=False, ) @@ -165,13 +165,13 @@ class HostedModerationConfig(BaseSettings): """ HOSTED_MODERATION_ENABLED: bool = Field( - description='', + description="", default=False, ) HOSTED_MODERATION_PROVIDERS: str = Field( - description='', - default='', + description="", + default="", ) @@ -181,15 +181,15 @@ class HostedFetchAppTemplateConfig(BaseSettings): """ HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field( - description='the mode for fetching app templates,' - ' default to remote,' - ' available values: remote, db, builtin', - default='remote', + description="the mode for fetching app templates," + " default to remote," + " available values: remote, db, builtin", + default="remote", ) HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN: str = Field( - description='the domain for fetching remote app templates', - default='https://tmpl.dify.ai', + description="the domain for fetching remote app templates", + default="https://tmpl.dify.ai", ) @@ -202,7 +202,6 @@ class HostedServiceConfig( HostedOpenAiConfig, HostedSparkConfig, HostedZhipuAIConfig, - # moderation HostedModerationConfig, ): diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 07688e9aeb..05e9b8f7a6 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -28,108 +28,108 @@ from configs.middleware.vdb.weaviate_config import WeaviateConfig class StorageConfig(BaseSettings): STORAGE_TYPE: str = Field( - description='storage type,' - ' default to `local`,' - ' available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.', - default='local', + description="storage type," + " default to `local`," + " available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.", + default="local", ) STORAGE_LOCAL_PATH: str = Field( - description='local storage path', - default='storage', + description="local storage path", + default="storage", ) class VectorStoreConfig(BaseSettings): VECTOR_STORE: Optional[str] = Field( - description='vector store type', + description="vector store type", default=None, ) class KeywordStoreConfig(BaseSettings): KEYWORD_STORE: str = Field( - description='keyword store type', - default='jieba', + description="keyword store type", + default="jieba", ) class DatabaseConfig: DB_HOST: str = Field( - description='db host', - default='localhost', + description="db host", + default="localhost", ) DB_PORT: PositiveInt = Field( - description='db port', + description="db port", default=5432, ) DB_USERNAME: str = Field( - description='db username', - default='postgres', + description="db username", + default="postgres", ) DB_PASSWORD: str = Field( - description='db password', - default='', + description="db password", + default="", ) DB_DATABASE: str = Field( - description='db database', - default='dify', + description="db database", + default="dify", ) DB_CHARSET: str = Field( - description='db charset', - default='', + description="db charset", + default="", ) DB_EXTRAS: str = Field( - description='db extras options. Example: keepalives_idle=60&keepalives=1', - default='', + description="db extras options. Example: keepalives_idle=60&keepalives=1", + default="", ) SQLALCHEMY_DATABASE_URI_SCHEME: str = Field( - description='db uri scheme', - default='postgresql', + description="db uri scheme", + default="postgresql", ) @computed_field @property def SQLALCHEMY_DATABASE_URI(self) -> str: db_extras = ( - f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" - if self.DB_CHARSET - else self.DB_EXTRAS + f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else self.DB_EXTRAS ).strip("&") db_extras = f"?{db_extras}" if db_extras else "" - return (f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://" - f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}" - f"{db_extras}") + return ( + f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://" + f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}" + f"{db_extras}" + ) SQLALCHEMY_POOL_SIZE: NonNegativeInt = Field( - description='pool size of SqlAlchemy', + description="pool size of SqlAlchemy", default=30, ) SQLALCHEMY_MAX_OVERFLOW: NonNegativeInt = Field( - description='max overflows for SqlAlchemy', + description="max overflows for SqlAlchemy", default=10, ) SQLALCHEMY_POOL_RECYCLE: NonNegativeInt = Field( - description='SqlAlchemy pool recycle', + description="SqlAlchemy pool recycle", default=3600, ) SQLALCHEMY_POOL_PRE_PING: bool = Field( - description='whether to enable pool pre-ping in SqlAlchemy', + description="whether to enable pool pre-ping in SqlAlchemy", default=False, ) SQLALCHEMY_ECHO: bool | str = Field( - description='whether to enable SqlAlchemy echo', + description="whether to enable SqlAlchemy echo", default=False, ) @@ -137,35 +137,38 @@ class DatabaseConfig: @property def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]: return { - 'pool_size': self.SQLALCHEMY_POOL_SIZE, - 'max_overflow': self.SQLALCHEMY_MAX_OVERFLOW, - 'pool_recycle': self.SQLALCHEMY_POOL_RECYCLE, - 'pool_pre_ping': self.SQLALCHEMY_POOL_PRE_PING, - 'connect_args': {'options': '-c timezone=UTC'}, + "pool_size": self.SQLALCHEMY_POOL_SIZE, + "max_overflow": self.SQLALCHEMY_MAX_OVERFLOW, + "pool_recycle": self.SQLALCHEMY_POOL_RECYCLE, + "pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING, + "connect_args": {"options": "-c timezone=UTC"}, } class CeleryConfig(DatabaseConfig): CELERY_BACKEND: str = Field( - description='Celery backend, available values are `database`, `redis`', - default='database', + description="Celery backend, available values are `database`, `redis`", + default="database", ) CELERY_BROKER_URL: Optional[str] = Field( - description='CELERY_BROKER_URL', + description="CELERY_BROKER_URL", default=None, ) @computed_field @property def CELERY_RESULT_BACKEND(self) -> str | None: - return 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \ - if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL + return ( + "db+{}".format(self.SQLALCHEMY_DATABASE_URI) + if self.CELERY_BACKEND == "database" + else self.CELERY_BROKER_URL + ) @computed_field @property def BROKER_USE_SSL(self) -> bool: - return self.CELERY_BROKER_URL.startswith('rediss://') if self.CELERY_BROKER_URL else False + return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False class MiddlewareConfig( @@ -174,7 +177,6 @@ class MiddlewareConfig( DatabaseConfig, KeywordStoreConfig, RedisConfig, - # configs of storage and storage providers StorageConfig, AliyunOSSStorageConfig, @@ -183,7 +185,6 @@ class MiddlewareConfig( TencentCloudCOSStorageConfig, S3StorageConfig, OCIStorageConfig, - # configs of vdb and vdb providers VectorStoreConfig, AnalyticdbConfig, diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index 436ba5d4c0..cacdaf6fb6 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -8,32 +8,33 @@ class RedisConfig(BaseSettings): """ Redis configs """ + REDIS_HOST: str = Field( - description='Redis host', - default='localhost', + description="Redis host", + default="localhost", ) REDIS_PORT: PositiveInt = Field( - description='Redis port', + description="Redis port", default=6379, ) REDIS_USERNAME: Optional[str] = Field( - description='Redis username', + description="Redis username", default=None, ) REDIS_PASSWORD: Optional[str] = Field( - description='Redis password', + description="Redis password", default=None, ) REDIS_DB: NonNegativeInt = Field( - description='Redis database id, default to 0', + description="Redis database id, default to 0", default=0, ) REDIS_USE_SSL: bool = Field( - description='whether to use SSL for Redis connection', + description="whether to use SSL for Redis connection", default=False, ) diff --git a/api/configs/middleware/storage/aliyun_oss_storage_config.py b/api/configs/middleware/storage/aliyun_oss_storage_config.py index 19e6cafb12..78f70b7ad3 100644 --- a/api/configs/middleware/storage/aliyun_oss_storage_config.py +++ b/api/configs/middleware/storage/aliyun_oss_storage_config.py @@ -10,31 +10,31 @@ class AliyunOSSStorageConfig(BaseSettings): """ ALIYUN_OSS_BUCKET_NAME: Optional[str] = Field( - description='Aliyun OSS bucket name', + description="Aliyun OSS bucket name", default=None, ) ALIYUN_OSS_ACCESS_KEY: Optional[str] = Field( - description='Aliyun OSS access key', + description="Aliyun OSS access key", default=None, ) ALIYUN_OSS_SECRET_KEY: Optional[str] = Field( - description='Aliyun OSS secret key', + description="Aliyun OSS secret key", default=None, ) ALIYUN_OSS_ENDPOINT: Optional[str] = Field( - description='Aliyun OSS endpoint URL', + description="Aliyun OSS endpoint URL", default=None, ) ALIYUN_OSS_REGION: Optional[str] = Field( - description='Aliyun OSS region', + description="Aliyun OSS region", default=None, ) ALIYUN_OSS_AUTH_VERSION: Optional[str] = Field( - description='Aliyun OSS authentication version', + description="Aliyun OSS authentication version", default=None, ) diff --git a/api/configs/middleware/storage/amazon_s3_storage_config.py b/api/configs/middleware/storage/amazon_s3_storage_config.py index 2566fbd5da..bef9326108 100644 --- a/api/configs/middleware/storage/amazon_s3_storage_config.py +++ b/api/configs/middleware/storage/amazon_s3_storage_config.py @@ -10,36 +10,36 @@ class S3StorageConfig(BaseSettings): """ S3_ENDPOINT: Optional[str] = Field( - description='S3 storage endpoint', + description="S3 storage endpoint", default=None, ) S3_REGION: Optional[str] = Field( - description='S3 storage region', + description="S3 storage region", default=None, ) S3_BUCKET_NAME: Optional[str] = Field( - description='S3 storage bucket name', + description="S3 storage bucket name", default=None, ) S3_ACCESS_KEY: Optional[str] = Field( - description='S3 storage access key', + description="S3 storage access key", default=None, ) S3_SECRET_KEY: Optional[str] = Field( - description='S3 storage secret key', + description="S3 storage secret key", default=None, ) S3_ADDRESS_STYLE: str = Field( - description='S3 storage address style', - default='auto', + description="S3 storage address style", + default="auto", ) S3_USE_AWS_MANAGED_IAM: bool = Field( - description='whether to use aws managed IAM for S3', + description="whether to use aws managed IAM for S3", default=False, ) diff --git a/api/configs/middleware/storage/azure_blob_storage_config.py b/api/configs/middleware/storage/azure_blob_storage_config.py index 26e441c89b..10944b58ed 100644 --- a/api/configs/middleware/storage/azure_blob_storage_config.py +++ b/api/configs/middleware/storage/azure_blob_storage_config.py @@ -10,21 +10,21 @@ class AzureBlobStorageConfig(BaseSettings): """ AZURE_BLOB_ACCOUNT_NAME: Optional[str] = Field( - description='Azure Blob account name', + description="Azure Blob account name", default=None, ) AZURE_BLOB_ACCOUNT_KEY: Optional[str] = Field( - description='Azure Blob account key', + description="Azure Blob account key", default=None, ) AZURE_BLOB_CONTAINER_NAME: Optional[str] = Field( - description='Azure Blob container name', + description="Azure Blob container name", default=None, ) AZURE_BLOB_ACCOUNT_URL: Optional[str] = Field( - description='Azure Blob account URL', + description="Azure Blob account URL", default=None, ) diff --git a/api/configs/middleware/storage/google_cloud_storage_config.py b/api/configs/middleware/storage/google_cloud_storage_config.py index e1b0e34e0c..10a2d97e8d 100644 --- a/api/configs/middleware/storage/google_cloud_storage_config.py +++ b/api/configs/middleware/storage/google_cloud_storage_config.py @@ -10,11 +10,11 @@ class GoogleCloudStorageConfig(BaseSettings): """ GOOGLE_STORAGE_BUCKET_NAME: Optional[str] = Field( - description='Google Cloud storage bucket name', + description="Google Cloud storage bucket name", default=None, ) GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: Optional[str] = Field( - description='Google Cloud storage service account json base64', + description="Google Cloud storage service account json base64", default=None, ) diff --git a/api/configs/middleware/storage/oci_storage_config.py b/api/configs/middleware/storage/oci_storage_config.py index 6c0c067469..f8993496c9 100644 --- a/api/configs/middleware/storage/oci_storage_config.py +++ b/api/configs/middleware/storage/oci_storage_config.py @@ -10,27 +10,26 @@ class OCIStorageConfig(BaseSettings): """ OCI_ENDPOINT: Optional[str] = Field( - description='OCI storage endpoint', + description="OCI storage endpoint", default=None, ) OCI_REGION: Optional[str] = Field( - description='OCI storage region', + description="OCI storage region", default=None, ) OCI_BUCKET_NAME: Optional[str] = Field( - description='OCI storage bucket name', + description="OCI storage bucket name", default=None, ) OCI_ACCESS_KEY: Optional[str] = Field( - description='OCI storage access key', + description="OCI storage access key", default=None, ) OCI_SECRET_KEY: Optional[str] = Field( - description='OCI storage secret key', + description="OCI storage secret key", default=None, ) - diff --git a/api/configs/middleware/storage/tencent_cos_storage_config.py b/api/configs/middleware/storage/tencent_cos_storage_config.py index 1060c7b93e..765ac08f3e 100644 --- a/api/configs/middleware/storage/tencent_cos_storage_config.py +++ b/api/configs/middleware/storage/tencent_cos_storage_config.py @@ -10,26 +10,26 @@ class TencentCloudCOSStorageConfig(BaseSettings): """ TENCENT_COS_BUCKET_NAME: Optional[str] = Field( - description='Tencent Cloud COS bucket name', + description="Tencent Cloud COS bucket name", default=None, ) TENCENT_COS_REGION: Optional[str] = Field( - description='Tencent Cloud COS region', + description="Tencent Cloud COS region", default=None, ) TENCENT_COS_SECRET_ID: Optional[str] = Field( - description='Tencent Cloud COS secret id', + description="Tencent Cloud COS secret id", default=None, ) TENCENT_COS_SECRET_KEY: Optional[str] = Field( - description='Tencent Cloud COS secret key', + description="Tencent Cloud COS secret key", default=None, ) TENCENT_COS_SCHEME: Optional[str] = Field( - description='Tencent Cloud COS scheme', + description="Tencent Cloud COS scheme", default=None, ) diff --git a/api/configs/middleware/vdb/analyticdb_config.py b/api/configs/middleware/vdb/analyticdb_config.py index db2899265e..04f5b0e5bf 100644 --- a/api/configs/middleware/vdb/analyticdb_config.py +++ b/api/configs/middleware/vdb/analyticdb_config.py @@ -10,35 +10,28 @@ class AnalyticdbConfig(BaseModel): https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled """ - ANALYTICDB_KEY_ID : Optional[str] = Field( - default=None, - description="The Access Key ID provided by Alibaba Cloud for authentication." + ANALYTICDB_KEY_ID: Optional[str] = Field( + default=None, description="The Access Key ID provided by Alibaba Cloud for authentication." ) - ANALYTICDB_KEY_SECRET : Optional[str] = Field( - default=None, - description="The Secret Access Key corresponding to the Access Key ID for secure access." + ANALYTICDB_KEY_SECRET: Optional[str] = Field( + default=None, description="The Secret Access Key corresponding to the Access Key ID for secure access." ) - ANALYTICDB_REGION_ID : Optional[str] = Field( - default=None, - description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')." + ANALYTICDB_REGION_ID: Optional[str] = Field( + default=None, description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')." ) - ANALYTICDB_INSTANCE_ID : Optional[str] = Field( + ANALYTICDB_INSTANCE_ID: Optional[str] = Field( default=None, - description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456').." + description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456')..", ) - ANALYTICDB_ACCOUNT : Optional[str] = Field( - default=None, - description="The account name used to log in to the AnalyticDB instance." + ANALYTICDB_ACCOUNT: Optional[str] = Field( + default=None, description="The account name used to log in to the AnalyticDB instance." ) - ANALYTICDB_PASSWORD : Optional[str] = Field( - default=None, - description="The password associated with the AnalyticDB account for authentication." + ANALYTICDB_PASSWORD: Optional[str] = Field( + default=None, description="The password associated with the AnalyticDB account for authentication." ) - ANALYTICDB_NAMESPACE : Optional[str] = Field( - default=None, - description="The namespace within AnalyticDB for schema isolation." + ANALYTICDB_NAMESPACE: Optional[str] = Field( + default=None, description="The namespace within AnalyticDB for schema isolation." ) - ANALYTICDB_NAMESPACE_PASSWORD : Optional[str] = Field( - default=None, - description="The password for accessing the specified namespace within the AnalyticDB instance." + ANALYTICDB_NAMESPACE_PASSWORD: Optional[str] = Field( + default=None, description="The password for accessing the specified namespace within the AnalyticDB instance." ) diff --git a/api/configs/middleware/vdb/chroma_config.py b/api/configs/middleware/vdb/chroma_config.py index f365879efb..d386623a56 100644 --- a/api/configs/middleware/vdb/chroma_config.py +++ b/api/configs/middleware/vdb/chroma_config.py @@ -10,31 +10,31 @@ class ChromaConfig(BaseSettings): """ CHROMA_HOST: Optional[str] = Field( - description='Chroma host', + description="Chroma host", default=None, ) CHROMA_PORT: PositiveInt = Field( - description='Chroma port', + description="Chroma port", default=8000, ) CHROMA_TENANT: Optional[str] = Field( - description='Chroma database', + description="Chroma database", default=None, ) CHROMA_DATABASE: Optional[str] = Field( - description='Chroma database', + description="Chroma database", default=None, ) CHROMA_AUTH_PROVIDER: Optional[str] = Field( - description='Chroma authentication provider', + description="Chroma authentication provider", default=None, ) CHROMA_AUTH_CREDENTIALS: Optional[str] = Field( - description='Chroma authentication credentials', + description="Chroma authentication credentials", default=None, ) diff --git a/api/configs/middleware/vdb/milvus_config.py b/api/configs/middleware/vdb/milvus_config.py index 01502d4590..85466cd5cc 100644 --- a/api/configs/middleware/vdb/milvus_config.py +++ b/api/configs/middleware/vdb/milvus_config.py @@ -10,31 +10,31 @@ class MilvusConfig(BaseSettings): """ MILVUS_HOST: Optional[str] = Field( - description='Milvus host', + description="Milvus host", default=None, ) MILVUS_PORT: PositiveInt = Field( - description='Milvus RestFul API port', + description="Milvus RestFul API port", default=9091, ) MILVUS_USER: Optional[str] = Field( - description='Milvus user', + description="Milvus user", default=None, ) MILVUS_PASSWORD: Optional[str] = Field( - description='Milvus password', + description="Milvus password", default=None, ) MILVUS_SECURE: bool = Field( - description='whether to use SSL connection for Milvus', + description="whether to use SSL connection for Milvus", default=False, ) MILVUS_DATABASE: str = Field( - description='Milvus database, default to `default`', - default='default', + description="Milvus database, default to `default`", + default="default", ) diff --git a/api/configs/middleware/vdb/myscale_config.py b/api/configs/middleware/vdb/myscale_config.py index 895cd6f176..6451d26e1c 100644 --- a/api/configs/middleware/vdb/myscale_config.py +++ b/api/configs/middleware/vdb/myscale_config.py @@ -1,4 +1,3 @@ - from pydantic import BaseModel, Field, PositiveInt @@ -8,31 +7,31 @@ class MyScaleConfig(BaseModel): """ MYSCALE_HOST: str = Field( - description='MyScale host', - default='localhost', + description="MyScale host", + default="localhost", ) MYSCALE_PORT: PositiveInt = Field( - description='MyScale port', + description="MyScale port", default=8123, ) MYSCALE_USER: str = Field( - description='MyScale user', - default='default', + description="MyScale user", + default="default", ) MYSCALE_PASSWORD: str = Field( - description='MyScale password', - default='', + description="MyScale password", + default="", ) MYSCALE_DATABASE: str = Field( - description='MyScale database name', - default='default', + description="MyScale database name", + default="default", ) MYSCALE_FTS_PARAMS: str = Field( - description='MyScale fts index parameters', - default='', + description="MyScale fts index parameters", + default="", ) diff --git a/api/configs/middleware/vdb/opensearch_config.py b/api/configs/middleware/vdb/opensearch_config.py index 15d6f5b6a9..5823dc1433 100644 --- a/api/configs/middleware/vdb/opensearch_config.py +++ b/api/configs/middleware/vdb/opensearch_config.py @@ -10,26 +10,26 @@ class OpenSearchConfig(BaseSettings): """ OPENSEARCH_HOST: Optional[str] = Field( - description='OpenSearch host', + description="OpenSearch host", default=None, ) OPENSEARCH_PORT: PositiveInt = Field( - description='OpenSearch port', + description="OpenSearch port", default=9200, ) OPENSEARCH_USER: Optional[str] = Field( - description='OpenSearch user', + description="OpenSearch user", default=None, ) OPENSEARCH_PASSWORD: Optional[str] = Field( - description='OpenSearch password', + description="OpenSearch password", default=None, ) OPENSEARCH_SECURE: bool = Field( - description='whether to use SSL connection for OpenSearch', + description="whether to use SSL connection for OpenSearch", default=False, ) diff --git a/api/configs/middleware/vdb/oracle_config.py b/api/configs/middleware/vdb/oracle_config.py index 888fc19492..62614ae870 100644 --- a/api/configs/middleware/vdb/oracle_config.py +++ b/api/configs/middleware/vdb/oracle_config.py @@ -10,26 +10,26 @@ class OracleConfig(BaseSettings): """ ORACLE_HOST: Optional[str] = Field( - description='ORACLE host', + description="ORACLE host", default=None, ) ORACLE_PORT: Optional[PositiveInt] = Field( - description='ORACLE port', + description="ORACLE port", default=1521, ) ORACLE_USER: Optional[str] = Field( - description='ORACLE user', + description="ORACLE user", default=None, ) ORACLE_PASSWORD: Optional[str] = Field( - description='ORACLE password', + description="ORACLE password", default=None, ) ORACLE_DATABASE: Optional[str] = Field( - description='ORACLE database', + description="ORACLE database", default=None, ) diff --git a/api/configs/middleware/vdb/pgvector_config.py b/api/configs/middleware/vdb/pgvector_config.py index 8a677f60a3..39a7c1d8d5 100644 --- a/api/configs/middleware/vdb/pgvector_config.py +++ b/api/configs/middleware/vdb/pgvector_config.py @@ -10,26 +10,26 @@ class PGVectorConfig(BaseSettings): """ PGVECTOR_HOST: Optional[str] = Field( - description='PGVector host', + description="PGVector host", default=None, ) PGVECTOR_PORT: Optional[PositiveInt] = Field( - description='PGVector port', + description="PGVector port", default=5433, ) PGVECTOR_USER: Optional[str] = Field( - description='PGVector user', + description="PGVector user", default=None, ) PGVECTOR_PASSWORD: Optional[str] = Field( - description='PGVector password', + description="PGVector password", default=None, ) PGVECTOR_DATABASE: Optional[str] = Field( - description='PGVector database', + description="PGVector database", default=None, ) diff --git a/api/configs/middleware/vdb/pgvectors_config.py b/api/configs/middleware/vdb/pgvectors_config.py index 39f52f22ff..c40e5ff921 100644 --- a/api/configs/middleware/vdb/pgvectors_config.py +++ b/api/configs/middleware/vdb/pgvectors_config.py @@ -10,26 +10,26 @@ class PGVectoRSConfig(BaseSettings): """ PGVECTO_RS_HOST: Optional[str] = Field( - description='PGVectoRS host', + description="PGVectoRS host", default=None, ) PGVECTO_RS_PORT: Optional[PositiveInt] = Field( - description='PGVectoRS port', + description="PGVectoRS port", default=5431, ) PGVECTO_RS_USER: Optional[str] = Field( - description='PGVectoRS user', + description="PGVectoRS user", default=None, ) PGVECTO_RS_PASSWORD: Optional[str] = Field( - description='PGVectoRS password', + description="PGVectoRS password", default=None, ) PGVECTO_RS_DATABASE: Optional[str] = Field( - description='PGVectoRS database', + description="PGVectoRS database", default=None, ) diff --git a/api/configs/middleware/vdb/qdrant_config.py b/api/configs/middleware/vdb/qdrant_config.py index c85bf9c7dc..27f75491c9 100644 --- a/api/configs/middleware/vdb/qdrant_config.py +++ b/api/configs/middleware/vdb/qdrant_config.py @@ -10,26 +10,26 @@ class QdrantConfig(BaseSettings): """ QDRANT_URL: Optional[str] = Field( - description='Qdrant url', + description="Qdrant url", default=None, ) QDRANT_API_KEY: Optional[str] = Field( - description='Qdrant api key', + description="Qdrant api key", default=None, ) QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field( - description='Qdrant client timeout in seconds', + description="Qdrant client timeout in seconds", default=20, ) QDRANT_GRPC_ENABLED: bool = Field( - description='whether enable grpc support for Qdrant connection', + description="whether enable grpc support for Qdrant connection", default=False, ) QDRANT_GRPC_PORT: PositiveInt = Field( - description='Qdrant grpc port', + description="Qdrant grpc port", default=6334, ) diff --git a/api/configs/middleware/vdb/relyt_config.py b/api/configs/middleware/vdb/relyt_config.py index be93185f3c..66b9ecc03f 100644 --- a/api/configs/middleware/vdb/relyt_config.py +++ b/api/configs/middleware/vdb/relyt_config.py @@ -10,26 +10,26 @@ class RelytConfig(BaseSettings): """ RELYT_HOST: Optional[str] = Field( - description='Relyt host', + description="Relyt host", default=None, ) RELYT_PORT: PositiveInt = Field( - description='Relyt port', + description="Relyt port", default=9200, ) RELYT_USER: Optional[str] = Field( - description='Relyt user', + description="Relyt user", default=None, ) RELYT_PASSWORD: Optional[str] = Field( - description='Relyt password', + description="Relyt password", default=None, ) RELYT_DATABASE: Optional[str] = Field( - description='Relyt database', - default='default', + description="Relyt database", + default="default", ) diff --git a/api/configs/middleware/vdb/tencent_vector_config.py b/api/configs/middleware/vdb/tencent_vector_config.py index 531ec84068..46b4cb6a24 100644 --- a/api/configs/middleware/vdb/tencent_vector_config.py +++ b/api/configs/middleware/vdb/tencent_vector_config.py @@ -10,41 +10,41 @@ class TencentVectorDBConfig(BaseSettings): """ TENCENT_VECTOR_DB_URL: Optional[str] = Field( - description='Tencent Vector URL', + description="Tencent Vector URL", default=None, ) TENCENT_VECTOR_DB_API_KEY: Optional[str] = Field( - description='Tencent Vector API key', + description="Tencent Vector API key", default=None, ) TENCENT_VECTOR_DB_TIMEOUT: PositiveInt = Field( - description='Tencent Vector timeout in seconds', + description="Tencent Vector timeout in seconds", default=30, ) TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field( - description='Tencent Vector username', + description="Tencent Vector username", default=None, ) TENCENT_VECTOR_DB_PASSWORD: Optional[str] = Field( - description='Tencent Vector password', + description="Tencent Vector password", default=None, ) TENCENT_VECTOR_DB_SHARD: PositiveInt = Field( - description='Tencent Vector sharding number', + description="Tencent Vector sharding number", default=1, ) TENCENT_VECTOR_DB_REPLICAS: NonNegativeInt = Field( - description='Tencent Vector replicas', + description="Tencent Vector replicas", default=2, ) TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field( - description='Tencent Vector Database', + description="Tencent Vector Database", default=None, ) diff --git a/api/configs/middleware/vdb/tidb_vector_config.py b/api/configs/middleware/vdb/tidb_vector_config.py index 8d459691a8..dbcb276c01 100644 --- a/api/configs/middleware/vdb/tidb_vector_config.py +++ b/api/configs/middleware/vdb/tidb_vector_config.py @@ -10,26 +10,26 @@ class TiDBVectorConfig(BaseSettings): """ TIDB_VECTOR_HOST: Optional[str] = Field( - description='TiDB Vector host', + description="TiDB Vector host", default=None, ) TIDB_VECTOR_PORT: Optional[PositiveInt] = Field( - description='TiDB Vector port', + description="TiDB Vector port", default=4000, ) TIDB_VECTOR_USER: Optional[str] = Field( - description='TiDB Vector user', + description="TiDB Vector user", default=None, ) TIDB_VECTOR_PASSWORD: Optional[str] = Field( - description='TiDB Vector password', + description="TiDB Vector password", default=None, ) TIDB_VECTOR_DATABASE: Optional[str] = Field( - description='TiDB Vector database', + description="TiDB Vector database", default=None, ) diff --git a/api/configs/middleware/vdb/weaviate_config.py b/api/configs/middleware/vdb/weaviate_config.py index b985ecea12..63d1022f6a 100644 --- a/api/configs/middleware/vdb/weaviate_config.py +++ b/api/configs/middleware/vdb/weaviate_config.py @@ -10,21 +10,21 @@ class WeaviateConfig(BaseSettings): """ WEAVIATE_ENDPOINT: Optional[str] = Field( - description='Weaviate endpoint URL', + description="Weaviate endpoint URL", default=None, ) WEAVIATE_API_KEY: Optional[str] = Field( - description='Weaviate API key', + description="Weaviate API key", default=None, ) WEAVIATE_GRPC_ENABLED: bool = Field( - description='whether to enable gRPC for Weaviate connection', + description="whether to enable gRPC for Weaviate connection", default=True, ) WEAVIATE_BATCH_SIZE: PositiveInt = Field( - description='Weaviate batch size', + description="Weaviate batch size", default=100, ) diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py index a7c5eb15a3..a6ab550de1 100644 --- a/api/configs/packaging/__init__.py +++ b/api/configs/packaging/__init__.py @@ -8,11 +8,11 @@ class PackagingInfo(BaseSettings): """ CURRENT_VERSION: str = Field( - description='Dify version', - default='0.7.1', + description="Dify version", + default="0.7.1", ) COMMIT_SHA: str = Field( description="SHA-1 checksum of the git commit used to build the app", - default='', + default="", ) diff --git a/api/pyproject.toml b/api/pyproject.toml index 47b638573b..6175fdbda7 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -77,7 +77,6 @@ exclude = [ "services/**/*.py", "tasks/**/*.py", "tests/**/*.py", - "configs/**/*.py", ] [tool.pytest_env] From b035c02f78924020e2329de36b473189851bb409 Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Fri, 23 Aug 2024 23:52:25 +0800 Subject: [PATCH 24/24] chore(api/tests): apply ruff reformat #7590 (#7591) Co-authored-by: -LAN- --- api/pyproject.toml | 1 - .../model_runtime/__mock/anthropic.py | 66 +-- .../model_runtime/__mock/google.py | 58 +- .../model_runtime/__mock/huggingface.py | 7 +- .../model_runtime/__mock/huggingface_chat.py | 27 +- .../model_runtime/__mock/huggingface_tei.py | 42 +- .../model_runtime/__mock/openai.py | 21 +- .../model_runtime/__mock/openai_chat.py | 178 +++--- .../model_runtime/__mock/openai_completion.py | 77 +-- .../model_runtime/__mock/openai_embeddings.py | 50 +- .../model_runtime/__mock/openai_moderation.py | 102 ++-- .../model_runtime/__mock/openai_remote.py | 11 +- .../__mock/openai_speech2text.py | 19 +- .../model_runtime/__mock/xinference.py | 135 +++-- .../model_runtime/anthropic/test_llm.py | 75 +-- .../model_runtime/anthropic/test_provider.py | 12 +- .../model_runtime/azure_openai/test_llm.py | 265 ++++----- .../azure_openai/test_text_embedding.py | 49 +- .../model_runtime/baichuan/test_llm.py | 132 ++--- .../model_runtime/baichuan/test_provider.py | 12 +- .../baichuan/test_text_embedding.py | 42 +- .../model_runtime/bedrock/test_llm.py | 66 +-- .../model_runtime/bedrock/test_provider.py | 6 +- .../model_runtime/chatglm/test_llm.py | 213 +++---- .../model_runtime/chatglm/test_provider.py | 14 +- .../model_runtime/cohere/test_llm.py | 179 ++---- .../model_runtime/cohere/test_provider.py | 10 +- .../model_runtime/cohere/test_rerank.py | 24 +- .../cohere/test_text_embedding.py | 38 +- .../model_runtime/google/test_llm.py | 191 +++---- .../model_runtime/google/test_provider.py | 12 +- .../model_runtime/huggingface_hub/test_llm.py | 260 ++++----- .../huggingface_hub/test_text_embedding.py | 89 ++- .../huggingface_tei/test_embeddings.py | 38 +- .../huggingface_tei/test_rerank.py | 38 +- .../model_runtime/hunyuan/test_llm.py | 69 +-- .../model_runtime/hunyuan/test_provider.py | 11 +- .../hunyuan/test_text_embedding.py | 50 +- .../model_runtime/jina/test_provider.py | 12 +- .../model_runtime/jina/test_text_embedding.py | 32 +- .../model_runtime/localai/test_embedding.py | 6 +- .../model_runtime/localai/test_llm.py | 154 ++---- .../model_runtime/localai/test_rerank.py | 56 +- .../model_runtime/localai/test_speech2text.py | 30 +- .../model_runtime/minimax/test_embedding.py | 41 +- .../model_runtime/minimax/test_llm.py | 107 ++-- .../model_runtime/minimax/test_provider.py | 8 +- .../model_runtime/novita/test_llm.py | 70 +-- .../model_runtime/novita/test_provider.py | 6 +- .../model_runtime/ollama/test_llm.py | 190 +++---- .../ollama/test_text_embedding.py | 48 +- .../model_runtime/openai/test_llm.py | 294 ++++------ .../model_runtime/openai/test_moderation.py | 33 +- .../model_runtime/openai/test_provider.py | 12 +- .../model_runtime/openai/test_speech2text.py | 35 +- .../openai/test_text_embedding.py | 45 +- .../openai_api_compatible/test_llm.py | 149 ++--- .../openai_api_compatible/test_speech2text.py | 17 +- .../test_text_embedding.py | 52 +- .../model_runtime/openllm/test_embedding.py | 33 +- .../model_runtime/openllm/test_llm.py | 69 +-- .../model_runtime/openrouter/test_llm.py | 71 +-- .../model_runtime/replicate/test_llm.py | 74 ++- .../replicate/test_text_embedding.py | 95 ++-- .../model_runtime/sagemaker/test_provider.py | 8 +- .../model_runtime/sagemaker/test_rerank.py | 12 +- .../sagemaker/test_text_embedding.py | 32 +- .../model_runtime/siliconflow/test_llm.py | 69 +-- .../siliconflow/test_provider.py | 10 +- .../model_runtime/siliconflow/test_rerank.py | 12 +- .../siliconflow/test_speech2text.py | 16 +- .../siliconflow/test_text_embedding.py | 4 +- .../model_runtime/spark/test_llm.py | 77 +-- .../model_runtime/spark/test_provider.py | 10 +- .../model_runtime/stepfun/test_llm.py | 121 ++-- .../test_model_provider_factory.py | 29 +- .../model_runtime/togetherai/test_llm.py | 74 +-- .../model_runtime/tongyi/test_llm.py | 67 +-- .../model_runtime/tongyi/test_provider.py | 8 +- .../tongyi/test_response_format.py | 18 +- .../model_runtime/upstage/test_llm.py | 177 ++---- .../model_runtime/upstage/test_provider.py | 12 +- .../upstage/test_text_embedding.py | 37 +- .../volcengine_maas/test_embedding.py | 70 ++- .../model_runtime/volcengine_maas/test_llm.py | 113 ++-- .../model_runtime/wenxin/test_embedding.py | 44 +- .../model_runtime/wenxin/test_llm.py | 195 +++---- .../model_runtime/wenxin/test_provider.py | 12 +- .../xinference/test_embeddings.py | 46 +- .../model_runtime/xinference/test_llm.py | 201 +++---- .../model_runtime/xinference/test_rerank.py | 32 +- .../model_runtime/zhinao/test_llm.py | 69 +-- .../model_runtime/zhinao/test_provider.py | 10 +- .../model_runtime/zhipuai/test_llm.py | 106 +--- .../model_runtime/zhipuai/test_provider.py | 10 +- .../zhipuai/test_text_embedding.py | 36 +- .../integration_tests/tools/__mock/http.py | 15 +- .../tools/__mock_server/openapi_todo.py | 6 +- .../tools/api_tool/test_api_tool.py | 51 +- .../tools/test_all_provider.py | 9 +- .../integration_tests/utils/parent_class.py | 2 +- .../utils/test_module_import_helper.py | 24 +- .../vdb/__mock/tcvectordb.py | 143 +++-- .../vdb/analyticdb/test_analyticdb.py | 5 +- .../vdb/chroma/test_chroma.py | 4 +- .../vdb/elasticsearch/test_elasticsearch.py | 11 +- .../vdb/milvus/test_milvus.py | 10 +- .../vdb/myscale/test_myscale.py | 2 +- .../vdb/opensearch/test_opensearch.py | 109 ++-- .../vdb/pgvecto_rs/test_pgvecto_rs.py | 13 +- .../vdb/qdrant/test_qdrant.py | 8 +- .../vdb/tcvectordb/test_tencent.py | 28 +- .../vdb/test_vector_store.py | 14 +- .../vdb/tidb_vector/test_tidb_vector.py | 14 +- .../vdb/weaviate/test_weaviate.py | 8 +- .../workflow/nodes/__mock/code_executor.py | 15 +- .../workflow/nodes/__mock/http.py | 30 +- .../nodes/code_executor/test_code_executor.py | 6 +- .../code_executor/test_code_javascript.py | 15 +- .../nodes/code_executor/test_code_jinja2.py | 25 +- .../nodes/code_executor/test_code_python3.py | 11 +- .../workflow/nodes/test_code.py | 278 +++++----- .../workflow/nodes/test_http.py | 519 +++++++++--------- .../workflow/nodes/test_llm.py | 236 ++++---- .../nodes/test_parameter_extractor.py | 499 ++++++++--------- .../workflow/nodes/test_template_transform.py | 45 +- .../workflow/nodes/test_tool.py | 101 ++-- .../unit_tests/configs/test_dify_config.py | 57 +- .../core/app/segments/test_factory.py | 98 ++-- .../core/app/segments/test_segment.py | 26 +- .../core/app/segments/test_variables.py | 50 +- .../unit_tests/core/helper/test_ssrf_proxy.py | 14 +- .../wenxin/test_text_embedding.py | 20 +- .../prompt/test_advanced_prompt_transform.py | 114 ++-- .../test_agent_history_prompt_transform.py | 38 +- .../core/prompt/test_prompt_transform.py | 26 +- .../prompt/test_simple_prompt_transform.py | 133 +++-- .../rag/datasource/vdb/milvus/test_milvus.py | 11 +- .../rag/extractor/firecrawl/test_firecrawl.py | 12 +- .../rag/extractor/test_notion_extractor.py | 57 +- .../unit_tests/core/test_model_manager.py | 35 +- .../unit_tests/core/test_provider_manager.py | 200 +++---- .../tools/test_tool_parameter_converter.py | 48 +- .../core/workflow/nodes/test_answer.py | 37 +- .../core/workflow/nodes/test_if_else.py | 232 +++----- .../workflow/nodes/test_variable_assigner.py | 98 ++-- api/tests/unit_tests/libs/test_pandas.py | 40 +- api/tests/unit_tests/libs/test_rsa.py | 2 +- api/tests/unit_tests/libs/test_yarl.py | 26 +- api/tests/unit_tests/models/test_account.py | 10 +- .../models/test_conversation_variable.py | 14 +- api/tests/unit_tests/models/test_workflow.py | 94 ++-- .../workflow/test_workflow_converter.py | 187 +++---- .../position_helper/test_position_helper.py | 66 +-- .../unit_tests/utils/yaml/test_yaml_utils.py | 38 +- 155 files changed, 4279 insertions(+), 5925 deletions(-) diff --git a/api/pyproject.toml b/api/pyproject.toml index 6175fdbda7..e05a51dc13 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -76,7 +76,6 @@ exclude = [ "migrations/**/*", "services/**/*.py", "tasks/**/*.py", - "tests/**/*.py", ] [tool.pytest_env] diff --git a/api/tests/integration_tests/model_runtime/__mock/anthropic.py b/api/tests/integration_tests/model_runtime/__mock/anthropic.py index 3326f874b0..79a3dc0394 100644 --- a/api/tests/integration_tests/model_runtime/__mock/anthropic.py +++ b/api/tests/integration_tests/model_runtime/__mock/anthropic.py @@ -22,23 +22,20 @@ from anthropic.types import ( ) from anthropic.types.message_delta_event import Delta -MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false") == "true" class MockAnthropicClass: @staticmethod def mocked_anthropic_chat_create_sync(model: str) -> Message: return Message( - id='msg-123', - type='message', - role='assistant', - content=[ContentBlock(text='hello, I\'m a chatbot from anthropic', type='text')], + id="msg-123", + type="message", + role="assistant", + content=[ContentBlock(text="hello, I'm a chatbot from anthropic", type="text")], model=model, - stop_reason='stop_sequence', - usage=Usage( - input_tokens=1, - output_tokens=1 - ) + stop_reason="stop_sequence", + usage=Usage(input_tokens=1, output_tokens=1), ) @staticmethod @@ -46,52 +43,43 @@ class MockAnthropicClass: full_response_text = "hello, I'm a chatbot from anthropic" yield MessageStartEvent( - type='message_start', + type="message_start", message=Message( - id='msg-123', + id="msg-123", content=[], - role='assistant', + role="assistant", model=model, stop_reason=None, - type='message', - usage=Usage( - input_tokens=1, - output_tokens=1 - ) - ) + type="message", + usage=Usage(input_tokens=1, output_tokens=1), + ), ) index = 0 for i in range(0, len(full_response_text)): yield ContentBlockDeltaEvent( - type='content_block_delta', - delta=TextDelta(text=full_response_text[i], type='text_delta'), - index=index + type="content_block_delta", delta=TextDelta(text=full_response_text[i], type="text_delta"), index=index ) index += 1 yield MessageDeltaEvent( - type='message_delta', - delta=Delta( - stop_reason='stop_sequence' - ), - usage=MessageDeltaUsage( - output_tokens=1 - ) + type="message_delta", delta=Delta(stop_reason="stop_sequence"), usage=MessageDeltaUsage(output_tokens=1) ) - yield MessageStopEvent(type='message_stop') + yield MessageStopEvent(type="message_stop") - def mocked_anthropic(self: Messages, *, - max_tokens: int, - messages: Iterable[MessageParam], - model: str, - stream: Literal[True], - **kwargs: Any - ) -> Union[Message, Stream[MessageStreamEvent]]: + def mocked_anthropic( + self: Messages, + *, + max_tokens: int, + messages: Iterable[MessageParam], + model: str, + stream: Literal[True], + **kwargs: Any, + ) -> Union[Message, Stream[MessageStreamEvent]]: if len(self._client.api_key) < 18: - raise anthropic.AuthenticationError('Invalid API key') + raise anthropic.AuthenticationError("Invalid API key") if stream: return MockAnthropicClass.mocked_anthropic_chat_create_stream(model=model) @@ -102,7 +90,7 @@ class MockAnthropicClass: @pytest.fixture def setup_anthropic_mock(request, monkeypatch: MonkeyPatch): if MOCK: - monkeypatch.setattr(Messages, 'create', MockAnthropicClass.mocked_anthropic) + monkeypatch.setattr(Messages, "create", MockAnthropicClass.mocked_anthropic) yield diff --git a/api/tests/integration_tests/model_runtime/__mock/google.py b/api/tests/integration_tests/model_runtime/__mock/google.py index d838e9890f..bc0684086f 100644 --- a/api/tests/integration_tests/model_runtime/__mock/google.py +++ b/api/tests/integration_tests/model_runtime/__mock/google.py @@ -12,63 +12,46 @@ from google.generativeai.client import _ClientManager, configure from google.generativeai.types import GenerateContentResponse from google.generativeai.types.generation_types import BaseGenerateContentResponse -current_api_key = '' +current_api_key = "" + class MockGoogleResponseClass: _done = False def __iter__(self): - full_response_text = 'it\'s google!' + full_response_text = "it's google!" for i in range(0, len(full_response_text) + 1, 1): if i == len(full_response_text): self._done = True yield GenerateContentResponse( - done=True, - iterator=None, - result=glm.GenerateContentResponse({ - - }), - chunks=[] + done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[] ) else: yield GenerateContentResponse( - done=False, - iterator=None, - result=glm.GenerateContentResponse({ - - }), - chunks=[] + done=False, iterator=None, result=glm.GenerateContentResponse({}), chunks=[] ) + class MockGoogleResponseCandidateClass: - finish_reason = 'stop' + finish_reason = "stop" @property def content(self) -> gag_content.Content: - return gag_content.Content( - parts=[ - gag_content.Part(text='it\'s google!') - ] - ) + return gag_content.Content(parts=[gag_content.Part(text="it's google!")]) + class MockGoogleClass: @staticmethod def generate_content_sync() -> GenerateContentResponse: - return GenerateContentResponse( - done=True, - iterator=None, - result=glm.GenerateContentResponse({ - - }), - chunks=[] - ) + return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]) @staticmethod def generate_content_stream() -> Generator[GenerateContentResponse, None, None]: return MockGoogleResponseClass() - def generate_content(self: GenerativeModel, + def generate_content( + self: GenerativeModel, contents: content_types.ContentsType, *, generation_config: generation_config_types.GenerationConfigType | None = None, @@ -79,21 +62,21 @@ class MockGoogleClass: global current_api_key if len(current_api_key) < 16: - raise Exception('Invalid API key') + raise Exception("Invalid API key") if stream: return MockGoogleClass.generate_content_stream() - + return MockGoogleClass.generate_content_sync() - + @property def generative_response_text(self) -> str: - return 'it\'s google!' - + return "it's google!" + @property def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]: return [MockGoogleResponseCandidateClass()] - + def make_client(self: _ClientManager, name: str): global current_api_key @@ -121,7 +104,8 @@ class MockGoogleClass: if not self.default_metadata: return client - + + @pytest.fixture def setup_google_mock(request, monkeypatch: MonkeyPatch): monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text) @@ -131,4 +115,4 @@ def setup_google_mock(request, monkeypatch: MonkeyPatch): yield - monkeypatch.undo() \ No newline at end of file + monkeypatch.undo() diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface.py b/api/tests/integration_tests/model_runtime/__mock/huggingface.py index a75b058d92..97038ef596 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface.py @@ -6,14 +6,15 @@ from huggingface_hub import InferenceClient from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass -MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + @pytest.fixture def setup_huggingface_mock(request, monkeypatch: MonkeyPatch): if MOCK: monkeypatch.setattr(InferenceClient, "text_generation", MockHuggingfaceChatClass.text_generation) - + yield if MOCK: - monkeypatch.undo() \ No newline at end of file + monkeypatch.undo() diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py index 1607624c3c..9ee76c935c 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py @@ -22,10 +22,8 @@ class MockHuggingfaceChatClass: details=Details( finish_reason="length", generated_tokens=6, - tokens=[ - Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6) - ] - ) + tokens=[Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)], + ), ) return response @@ -36,26 +34,23 @@ class MockHuggingfaceChatClass: for i in range(0, len(full_text)): response = TextGenerationStreamResponse( - token = Token(id=i, text=full_text[i], logprob=0.0, special=False), + token=Token(id=i, text=full_text[i], logprob=0.0, special=False), ) response.generated_text = full_text[i] - response.details = StreamDetails(finish_reason='stop_sequence', generated_tokens=1) + response.details = StreamDetails(finish_reason="stop_sequence", generated_tokens=1) yield response - def text_generation(self: InferenceClient, prompt: str, *, - stream: Literal[False] = ..., - model: Optional[str] = None, - **kwargs: Any + def text_generation( + self: InferenceClient, prompt: str, *, stream: Literal[False] = ..., model: Optional[str] = None, **kwargs: Any ) -> Union[TextGenerationResponse, Generator[TextGenerationStreamResponse, None, None]]: # check if key is valid - if not re.match(r'Bearer\shf\-[a-zA-Z0-9]{16,}', self.headers['authorization']): - raise BadRequestError('Invalid API key') - + if not re.match(r"Bearer\shf\-[a-zA-Z0-9]{16,}", self.headers["authorization"]): + raise BadRequestError("Invalid API key") + if model is None: - raise BadRequestError('Invalid model') - + raise BadRequestError("Invalid model") + if stream: return MockHuggingfaceChatClass.generate_create_stream(model) return MockHuggingfaceChatClass.generate_create_sync(model) - diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py index c2fe95974b..b37b109eba 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py @@ -5,10 +5,10 @@ class MockTEIClass: @staticmethod def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter: # During mock, we don't have a real server to query, so we just return a dummy value - if 'rerank' in model_name: - model_type = 'reranker' + if "rerank" in model_name: + model_type = "reranker" else: - model_type = 'embedding' + model_type = "embedding" return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1) @@ -17,16 +17,16 @@ class MockTEIClass: # Use space as token separator, and split the text into tokens tokenized_texts = [] for text in texts: - tokens = text.split(' ') + tokens = text.split(" ") current_index = 0 tokenized_text = [] for idx, token in enumerate(tokens): s_token = { - 'id': idx, - 'text': token, - 'special': False, - 'start': current_index, - 'stop': current_index + len(token), + "id": idx, + "text": token, + "special": False, + "start": current_index, + "stop": current_index + len(token), } current_index += len(token) + 1 tokenized_text.append(s_token) @@ -55,18 +55,18 @@ class MockTEIClass: embedding = [0.1] * 768 embeddings.append( { - 'object': 'embedding', - 'embedding': embedding, - 'index': idx, + "object": "embedding", + "embedding": embedding, + "index": idx, } ) return { - 'object': 'list', - 'data': embeddings, - 'model': 'MODEL_NAME', - 'usage': { - 'prompt_tokens': sum(len(text.split(' ')) for text in texts), - 'total_tokens': sum(len(text.split(' ')) for text in texts), + "object": "list", + "data": embeddings, + "model": "MODEL_NAME", + "usage": { + "prompt_tokens": sum(len(text.split(" ")) for text in texts), + "total_tokens": sum(len(text.split(" ")) for text in texts), }, } @@ -83,9 +83,9 @@ class MockTEIClass: for idx, text in enumerate(texts): reranked_docs.append( { - 'index': idx, - 'text': text, - 'score': 0.9, + "index": idx, + "text": text, + "score": 0.9, } ) # For mock, only return the first document diff --git a/api/tests/integration_tests/model_runtime/__mock/openai.py b/api/tests/integration_tests/model_runtime/__mock/openai.py index 0d3f0fbbea..6637f4f212 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai.py @@ -21,13 +21,17 @@ from tests.integration_tests.model_runtime.__mock.openai_remote import MockModel from tests.integration_tests.model_runtime.__mock.openai_speech2text import MockSpeech2TextClass -def mock_openai(monkeypatch: MonkeyPatch, methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]]) -> Callable[[], None]: +def mock_openai( + monkeypatch: MonkeyPatch, + methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]], +) -> Callable[[], None]: """ - mock openai module + mock openai module - :param monkeypatch: pytest monkeypatch fixture - :return: unpatch function + :param monkeypatch: pytest monkeypatch fixture + :return: unpatch function """ + def unpatch() -> None: monkeypatch.undo() @@ -52,15 +56,16 @@ def mock_openai(monkeypatch: MonkeyPatch, methods: list[Literal["completion", "c return unpatch -MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + @pytest.fixture def setup_openai_mock(request, monkeypatch): - methods = request.param if hasattr(request, 'param') else [] + methods = request.param if hasattr(request, "param") else [] if MOCK: unpatch = mock_openai(monkeypatch, methods=methods) - + yield if MOCK: - unpatch() \ No newline at end of file + unpatch() diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py index ba902e32ea..d9cd7b046e 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py @@ -43,62 +43,64 @@ class MockChatClass: if not functions or len(functions) == 0: return None function: completion_create_params.Function = functions[0] - function_name = function['name'] - function_description = function['description'] - function_parameters = function['parameters'] - function_parameters_type = function_parameters['type'] - if function_parameters_type != 'object': + function_name = function["name"] + function_description = function["description"] + function_parameters = function["parameters"] + function_parameters_type = function_parameters["type"] + if function_parameters_type != "object": return None - function_parameters_properties = function_parameters['properties'] - function_parameters_required = function_parameters['required'] + function_parameters_properties = function_parameters["properties"] + function_parameters_required = function_parameters["required"] parameters = {} for parameter_name, parameter in function_parameters_properties.items(): if parameter_name not in function_parameters_required: continue - parameter_type = parameter['type'] - if parameter_type == 'string': - if 'enum' in parameter: - if len(parameter['enum']) == 0: + parameter_type = parameter["type"] + if parameter_type == "string": + if "enum" in parameter: + if len(parameter["enum"]) == 0: continue - parameters[parameter_name] = parameter['enum'][0] + parameters[parameter_name] = parameter["enum"][0] else: - parameters[parameter_name] = 'kawaii' - elif parameter_type == 'integer': + parameters[parameter_name] = "kawaii" + elif parameter_type == "integer": parameters[parameter_name] = 114514 - elif parameter_type == 'number': + elif parameter_type == "number": parameters[parameter_name] = 1919810.0 - elif parameter_type == 'boolean': + elif parameter_type == "boolean": parameters[parameter_name] = True return FunctionCall(name=function_name, arguments=dumps(parameters)) - + @staticmethod - def generate_tool_calls(tools = NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]: + def generate_tool_calls(tools=NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]: list_tool_calls = [] if not tools or len(tools) == 0: return None tool = tools[0] - if 'type' in tools and tools['type'] != 'function': + if "type" in tools and tools["type"] != "function": return None - function = tool['function'] + function = tool["function"] function_call = MockChatClass.generate_function_call(functions=[function]) if function_call is None: return None - - list_tool_calls.append(ChatCompletionMessageToolCall( - id='sakurajima-mai', - function=Function( - name=function_call.name, - arguments=function_call.arguments, - ), - type='function' - )) + + list_tool_calls.append( + ChatCompletionMessageToolCall( + id="sakurajima-mai", + function=Function( + name=function_call.name, + arguments=function_call.arguments, + ), + type="function", + ) + ) return list_tool_calls - + @staticmethod def mocked_openai_chat_create_sync( model: str, @@ -111,30 +113,27 @@ class MockChatClass: tool_calls = MockChatClass.generate_tool_calls(tools=tools) return _ChatCompletion( - id='cmpl-3QJQa5jXJ5Z5X', + id="cmpl-3QJQa5jXJ5Z5X", choices=[ _ChatCompletionChoice( - finish_reason='content_filter', + finish_reason="content_filter", index=0, message=ChatCompletionMessage( - content='elaina', - role='assistant', - function_call=function_call, - tool_calls=tool_calls - ) + content="elaina", role="assistant", function_call=function_call, tool_calls=tool_calls + ), ) ], created=int(time()), model=model, - object='chat.completion', - system_fingerprint='', + object="chat.completion", + system_fingerprint="", usage=CompletionUsage( prompt_tokens=2, completion_tokens=1, total_tokens=3, - ) + ), ) - + @staticmethod def mocked_openai_chat_create_stream( model: str, @@ -150,36 +149,40 @@ class MockChatClass: for i in range(0, len(full_text) + 1): if i == len(full_text): yield ChatCompletionChunk( - id='cmpl-3QJQa5jXJ5Z5X', + id="cmpl-3QJQa5jXJ5Z5X", choices=[ Choice( delta=ChoiceDelta( - content='', + content="", function_call=ChoiceDeltaFunctionCall( name=function_call.name, arguments=function_call.arguments, - ) if function_call else None, - role='assistant', + ) + if function_call + else None, + role="assistant", tool_calls=[ ChoiceDeltaToolCall( index=0, - id='misaka-mikoto', + id="misaka-mikoto", function=ChoiceDeltaToolCallFunction( name=tool_calls[0].function.name, arguments=tool_calls[0].function.arguments, ), - type='function' + type="function", ) - ] if tool_calls and len(tool_calls) > 0 else None + ] + if tool_calls and len(tool_calls) > 0 + else None, ), - finish_reason='function_call', + finish_reason="function_call", index=0, ) ], created=int(time()), model=model, - object='chat.completion.chunk', - system_fingerprint='', + object="chat.completion.chunk", + system_fingerprint="", usage=CompletionUsage( prompt_tokens=2, completion_tokens=17, @@ -188,30 +191,45 @@ class MockChatClass: ) else: yield ChatCompletionChunk( - id='cmpl-3QJQa5jXJ5Z5X', + id="cmpl-3QJQa5jXJ5Z5X", choices=[ Choice( delta=ChoiceDelta( content=full_text[i], - role='assistant', + role="assistant", ), - finish_reason='content_filter', + finish_reason="content_filter", index=0, ) ], created=int(time()), model=model, - object='chat.completion.chunk', - system_fingerprint='', + object="chat.completion.chunk", + system_fingerprint="", ) - def chat_create(self: Completions, *, + def chat_create( + self: Completions, + *, messages: list[ChatCompletionMessageParam], - model: Union[str,Literal[ - "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613", - "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", - "gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301", - "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613"], + model: Union[ + str, + Literal[ + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", + ], ], functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN, response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, @@ -220,24 +238,32 @@ class MockChatClass: **kwargs: Any, ): openai_models = [ - "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613", - "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", - "gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301", - "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", ] - azure_openai_models = [ - "gpt35", "gpt-4v", "gpt-35-turbo" - ] - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()): - raise InvokeAuthorizationError('Invalid base url') + azure_openai_models = ["gpt35", "gpt-4v", "gpt-35-turbo"] + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + raise InvokeAuthorizationError("Invalid base url") if model in openai_models + azure_openai_models: - if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI: + if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI: # sometime, provider use OpenAI compatible API will not have api key or have different api key format # so we only check if model is in openai_models - raise InvokeAuthorizationError('Invalid api key') + raise InvokeAuthorizationError("Invalid api key") if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI: - raise InvokeAuthorizationError('Invalid api key') + raise InvokeAuthorizationError("Invalid api key") if stream: return MockChatClass.mocked_openai_chat_create_stream(model=model, functions=functions, tools=tools) - - return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools) \ No newline at end of file + + return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools) diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_completion.py b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py index b0d2675905..c27e89248f 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_completion.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py @@ -17,9 +17,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError class MockCompletionsClass: @staticmethod - def mocked_openai_completion_create_sync( - model: str - ) -> CompletionMessage: + def mocked_openai_completion_create_sync(model: str) -> CompletionMessage: return CompletionMessage( id="cmpl-3QJQa5jXJ5Z5X", object="text_completion", @@ -38,13 +36,11 @@ class MockCompletionsClass: prompt_tokens=2, completion_tokens=1, total_tokens=3, - ) + ), ) - + @staticmethod - def mocked_openai_completion_create_stream( - model: str - ) -> Generator[CompletionMessage, None, None]: + def mocked_openai_completion_create_stream(model: str) -> Generator[CompletionMessage, None, None]: full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```" for i in range(0, len(full_text) + 1): if i == len(full_text): @@ -76,46 +72,59 @@ class MockCompletionsClass: model=model, system_fingerprint="", choices=[ - CompletionChoice( - text=full_text[i], - index=0, - logprobs=None, - finish_reason="content_filter" - ) + CompletionChoice(text=full_text[i], index=0, logprobs=None, finish_reason="content_filter") ], ) - def completion_create(self: Completions, *, model: Union[ - str, Literal["babbage-002", "davinci-002", "gpt-3.5-turbo-instruct", - "text-davinci-003", "text-davinci-002", "text-davinci-001", - "code-davinci-002", "text-curie-001", "text-babbage-001", - "text-ada-001"], + def completion_create( + self: Completions, + *, + model: Union[ + str, + Literal[ + "babbage-002", + "davinci-002", + "gpt-3.5-turbo-instruct", + "text-davinci-003", + "text-davinci-002", + "text-davinci-001", + "code-davinci-002", + "text-curie-001", + "text-babbage-001", + "text-ada-001", + ], ], prompt: Union[str, list[str], list[int], list[list[int]], None], stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN, - **kwargs: Any + **kwargs: Any, ): openai_models = [ - "babbage-002", "davinci-002", "gpt-3.5-turbo-instruct", "text-davinci-003", "text-davinci-002", "text-davinci-001", - "code-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001", - ] - azure_openai_models = [ - "gpt-35-turbo-instruct" + "babbage-002", + "davinci-002", + "gpt-3.5-turbo-instruct", + "text-davinci-003", + "text-davinci-002", + "text-davinci-001", + "code-davinci-002", + "text-curie-001", + "text-babbage-001", + "text-ada-001", ] + azure_openai_models = ["gpt-35-turbo-instruct"] - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()): - raise InvokeAuthorizationError('Invalid base url') + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + raise InvokeAuthorizationError("Invalid base url") if model in openai_models + azure_openai_models: - if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI: + if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI: # sometime, provider use OpenAI compatible API will not have api key or have different api key format # so we only check if model is in openai_models - raise InvokeAuthorizationError('Invalid api key') + raise InvokeAuthorizationError("Invalid api key") if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI: - raise InvokeAuthorizationError('Invalid api key') - + raise InvokeAuthorizationError("Invalid api key") + if not prompt: - raise BadRequestError('Invalid prompt') + raise BadRequestError("Invalid prompt") if stream: return MockCompletionsClass.mocked_openai_completion_create_stream(model=model) - - return MockCompletionsClass.mocked_openai_completion_create_sync(model=model) \ No newline at end of file + + return MockCompletionsClass.mocked_openai_completion_create_sync(model=model) diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py index eccdbd3479..4138cdd40d 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py @@ -12,48 +12,39 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError class MockEmbeddingsClass: def create_embeddings( - self: Embeddings, *, + self: Embeddings, + *, input: Union[str, list[str], list[int], list[list[int]]], model: Union[str, Literal["text-embedding-ada-002"]], encoding_format: Literal["float", "base64"] | NotGiven = NOT_GIVEN, - **kwargs: Any + **kwargs: Any, ) -> CreateEmbeddingResponse: if isinstance(input, str): input = [input] - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()): - raise InvokeAuthorizationError('Invalid base url') - + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + raise InvokeAuthorizationError("Invalid base url") + if len(self._client.api_key) < 18: - raise InvokeAuthorizationError('Invalid API key') - - if encoding_format == 'float': + raise InvokeAuthorizationError("Invalid API key") + + if encoding_format == "float": return CreateEmbeddingResponse( data=[ - Embedding( - embedding=[0.23333 for _ in range(233)], - index=i, - object='embedding' - ) for i in range(len(input)) + Embedding(embedding=[0.23333 for _ in range(233)], index=i, object="embedding") + for i in range(len(input)) ], model=model, - object='list', + object="list", # marked: usage of embeddings should equal the number of testcase - usage=Usage( - prompt_tokens=2, - total_tokens=2 - ) + usage=Usage(prompt_tokens=2, total_tokens=2), ) - - embeddings = 'VEfNvMLUnrwFleO8hcj9vEE/yrzyjOA84E1MvNfoCrxjrI+8sZUKvNgrBT17uY07gJ/IvNvhHLrUemc8KXXGumalIT3YKwU7ZsnbPMhATrwTt6u8JEwRPNMmCjxGREW7TRKvu6/MG7zAyDU8wXLkuuMDZDsXsL28zHzaOw0IArzOiMO8LtASvPKM4Dul5l+80V0bPGVDZ7wYNrI89ucsvJZdYztzRm+8P8ysOyGbc7zrdgK9sdiEPKQ8sbulKdq7KIgdvKIMDj25dNc8k0AXPBn/oLzrdgK8IXe5uz0Dvrt50V68tTjLO4ZOcjoG9x29oGfZufiwmzwMDXy8EL6ZPHvdx7nKjzE8+LCbPG22hTs3EZq7TM+0POrRzTxVZo084wPkO8Nak7z8cpw8pDwxvA2T8LvBC7C72fltvC8Atjp3fYE8JHDLvEYgC7xAdls8YiabPPkEeTzPUbK8gOLCPEBSIbyt5Oy8CpreusNakzywUhA824vLPHRlr7zAhTs7IZtzvHd9AT2xY/O6ok8IvOihqrql5l88K4EvuknWorvYKwW9iXkbvGMTRLw5qPG7onPCPLgNIzwAbK67ftbZPMxYILvAyDW9TLB0vIid1buzCKi7u+d0u8iDSLxNVam8PZyJPNxnETvVANw8Oi5mu9nVszzl65I7DIKNvLGVirxsMJE7tPXQu2PvCT1zRm87p1l9uyRMkbsdfqe8U52ePHRlr7wt9Mw8/C8ivTu02rwJFGq8tpoFPWnC7blWumq7sfy+vG1zCzy9Nlg8iv+PuvxT3DuLU228kVhoOkmTqDrv1kg8ocmTu1WpBzsKml48DzglvI8ECzxwTd27I+pWvIWkQ7xUR007GqlPPBFEDrzGECu865q8PI7BkDwNxYc8tgG6ullMSLsIajs84lk1PNLjD70mv648ZmInO2tnIjzvb5Q8o5KCPLo9xrwKMyq9QqGEvI8ECzxO2508ATUdPRAlTry5kxc8KVGMPJyBHjxIUC476KGqvIU9DzwX87c88PUIParrWrzdlzS/G3K+uzEw2TxB2BU86AhfPAMiRj2dK808a85WPPCft7xU4Bg95Q9NPDxZjzwrpek7yNkZvHa0EjyQ0nM6Nq9fuyjvUbsRq8I7CAMHO3VSWLyuauE7U1qkvPkEeTxs7ZY7B6FMO48Eizy75/S7ieBPvB07rTxmyVu8onPCO5rc6Tu7XIa7oEMfPYngT7u24vk7/+W5PE8eGDxJ1iI9t4cuvBGHiLyH1GY7jfghu+oUSDwa7Mk7iXmbuut2grrq8I2563v8uyofdTxRTrs44lm1vMeWnzukf6s7r4khvEKhhDyhyZO8G5Z4Oy56wTz4sBs81Zknuz3fg7wnJuO74n1vvASEADu98128gUl3vBtyvrtZCU47yep8u5FYaDx2G0e8a85WO5cmUjz3kds8qgqbPCUaerx50d67WKIZPI7BkDua3Om74vKAvL3zXbzXpRA9CI51vLo9xryKzXg7tXtFO9RWLTwnJuM854LqPEIs8zuO5cq8d8V1u9P0cjrQ++C8cGwdPDdUlLoOGeW8auEtu8Z337nlzFK8aRg/vFCkDD0nRSM879bIvKUFID1iStU8EL6ZvLufgLtKgNE7KVEMvJOnSzwahRU895HbvJiIjLvc8n88bmC0PPLP2rywM9C7jTscOoS3mjy/Znu7dhvHuu5Q1Dyq61o6CI71u09hkry0jhw8gb6IPI8EC7uoVAM8gs9rvGM3fjx2G8e81FYtu/ojubyYRRK72Riuu83elDtNNmk70/TyuzUFsbvgKZI7onNCvAehzLumr8679R6+urr6SztX2So8Bl5SOwSEgLv5NpA8LwC2PGPvibzJ6vw7H2tQvOtXwrzXpRC8j0z/uxwcbTy2vr+8VWYNu+t2ArwKmt68NKN2O3XrIzw9A747UU47vaavzjwU+qW8YBqyvE02aTyEt5o8cCmjOxtyPrxs7ZY775NOu+SJWLxMJQY8/bWWu6IMDrzSSsQ7GSPbPLlQnbpVzcE7Pka4PJ96sLycxJg8v/9GPO2HZTyeW3C8Vpawtx2iYTwWBg87/qI/OviwGzxyWcY7M9WNPIA4FD32C2e8tNGWPJ43trxCoYS8FGHavItTbbu7n4C80NemPLm30Ty1OMu7vG1pvG3aPztBP0o75Q/NPJhFEj2V9i683PL/O97+aLz6iu27cdPRum/mKLwvVgc89fqDu3LA+jvm2Ls8mVZ1PIuFBD3ZGK47Cpreut7+aLziWTU8XSEgPMvSKzzO73e5040+vBlmVTxS1K+8mQ4BPZZ8o7w8FpW6OR0DPSSPCz21Vwu99fqDOjMYiDy7XAY8oYaZO+aVwTyX49c84OaXOqdZfTunEQk7B8AMvMDs7zo/D6e8OP5CvN9gIzwNCII8FefOPE026TpzIjU8XsvOO+J9b7rkIiQ8is34O+e0AbxBpv67hcj9uiPq1jtCoQQ8JfY/u86nAz0Wkf28LnrBPJlW9Tt8P4K7BbSjO9grhbyAOJS8G3K+vJLe3LzXpZA7NQUxPJs+JDz6vAS8QHZbvYNVYDrj3yk88PWIPOJ97zuSIVc8ZUPnPMqPsbx2cZi7QfzPOxYGDz2hqtO6H2tQO543NjyFPY+7JRUAOt0wgDyJeZu8MpKTu6AApTtg1ze82JI5vKllZjvrV0I7HX6nu7vndDxg1ze8jwQLu1ZTNjuJvBU7BXGpvAP+C7xJk6g8j2u/vBABlLzlqBi8M9WNutRWLTx0zGM9sHbKPLoZDDtmyVu8tpqFOvPumjyuRqe87lBUvFU0drxs7Za8ejMZOzJPGbyC7qu863v8PDPVjTxJ1iI7Ca01PLuAQLuNHFy7At9LOwP+i7tYxlO80NemO9elkDx45LU8h9TmuzxZjzz/5bk8p84OurvndLwAkGi7XL9luCSzRTwMgg08vrxMPKIwyDwdomG8K6VpPGPvCTxkmTi7M/lHPGxUSzxwKSM8wQuwvOqtkzrLFSa8SbdivAMixjw2r9+7xWt2vAyCDT1NEi87B8CMvG1zi7xpwm27MrbNO9R6Z7xJt+K7jNnhu9ZiFrve/ug55CKkvCwHJLqsOr47+ortvPwvIr2v8NW8YmmVOE+FTLywUhA8MTBZvMiDyLtx8hG8OEE9vMDsbzroCF88DelBOobnPbx+b6U8sbnEOywr3ro93wO9dMzjup2xwbwnRaO7cRZMu8Z337vS44+7VpYwvFWphzxKgNE8L1aHPLPFLbunzo66zFggPN+jHbs7tFo8nW7HO9JKRLyoeD28Fm1DPGZip7u5dNe7KMsXvFnlkzxQpAw7MrZNPHpX0zwSyoK7ayQovPR0Dz3gClK8/juLPDjaCLvqrZO7a4vcO9HEzzvife88KKzXvDmocbwpMkw7t2huvaIMjjznguo7Gy/EOzxZjzoLuZ48qi5VvCjLFzuDmNo654LquyrXgDy7XAa8e7mNvJ7QAb0Rq8K7ojBIvBN0MTuOfha8GoUVveb89bxMsHS8jV9WPPKM4LyAOJS8me9AvZv7qbsbcr47tuL5uaXmXzweKNa7rkYnPINV4Lxcv+W8tVcLvI8oxbzvbxS7oYaZu9+jHT0cHO08c7uAPCSzRTywUhA85xu2u+wBcTuJvJU8PBYVusTghzsnAim8acJtPFQE0zzFIwI9C7meO1DIRry7XAY8MKpkPJZd47suN0e5JTm6u6BDn7zfx1e8AJDoOr9CQbwaQps7x/1TPLTRFryqLtU8JybjPIXI/Tz6I7k6mVb1PMWKNryd1fs8Ok0mPHt2kzy9Ep48TTZpvPS3ibwGOpi8Ns4fPBqFlbr3Kqc8+QR5vHLA+rt7uY289YXyPI6iULxL4gu8Tv/XuycCKbwCnFG8C7kevVG1b7zIXw68GoWVO4rNeDnrM4i8MxgIPUNLs7zSoJW86ScfO+rRzbs6Cqw8NxGautP0cjw0wjY8CGq7vAkU6rxKgNG5+uA+vJXXbrwKM6o86vCNOu+yjjoQAZS8xATCOQVxKbynzo68wxcZvMhATjzS4488ArsRvNEaobwRh4i7t4euvAvd2DwnAik8UtQvvBFEDrz4sJs79gtnvOknnzy+vEy8D3sfPLH8vjzmLo28KVGMvOtXwjvpapm8HBxtPH3K8Lu753Q8/l9FvLvn9DomoG48fET8u9zy/7wMpke8zmQJu3oU2TzlD828KteAPAwNfLu+mBI5ldduPNZDVjq+vEy8eEvqvDHJpLwUPaC6qi7VPABsLjwFcSm72sJcu+bYO7v41NW8RiALvYB7DjzL0is7qLs3us1FSbzaf2K8MnNTuxABFDzF8Wo838fXvOBNzDzre3w8afQEvQE1nbulBaC78zEVvG5B9LzH/VM82Riuuwu5nrwsByQ8Y6yPvHXro7yQ0nM8nStNPJkyOzwnJmM80m7+O1VmjTzqrZM8dhvHOyAQBbz3baG8KTJMPOlqmbxsVEs8Pq3suy56QbzUVq08X3CDvAE1nTwUHuA7hue9vF8tCbvwOAO6F7A9ugd9kryqLtW7auEtu9ONPryPa7+8o9r2O570OzyFpEO8ntCBPOqtk7sykhO7lC1AOw2TcLswhiq6vx4HvP5fRbwuesG7Mk8ZvA4Z5TlfcAM9DrIwPL//xrzMm5q8JEwRPHBsnbxL4gu8jyjFu99gozrkZZ483GeRPLuAwDuYiIw8iv8PvK5Gpzx+b6W87Yflu3NGbzyE+hQ8a4tcPItT7bsoy5e8L1YHvWQyBDwrga86kPEzvBQ9oDxtl0W8lwKYvGpIYrxQ5wY8AJDovOLyALyw3f489JjJvMdTpTkKMyo8V9mqvH3K8LpyNYy8JHDLOixu2LpQ54Y8Q0uzu8LUnrs0wrY84vIAveihqjwfihA8DIKNvLDd/jywM1C7FB7gOxsLirxAUqE7sulnvH3K8DkAkGg8jsGQvO+TzrynWf287CCxvK4Drbwg8UQ8JRr6vFEqAbskjwu76q2TPNP0cjopDhK8dVJYvFIXKrxLn5G8AK8oPAb3HbxbOXE8Bvedun5Q5ThHyjk8QdiVvBXDlLw0o/Y7aLGKupkOgTxKPdc81kNWPtUAXLxUR827X1FDPf47izxsEVE8akhiPIhaWzxYX5+7hT0PPSrXgLxQC0E8i4WEvKUp2jtCLHM8DcWHO768zLxnK5a89R6+vH9czrorpem73h0pvAnwr7yKzXi8gDgUPf47Czq9zyO8728UOf34EDy6PUY76OSkvKZIGr2ZDgE8gzEmPG3av7v77Ce7/oP/O3MiNTtas/w8x1OlO/D1CDvDfs27ll1jO2Ufrbv1hXK8WINZuxN0sbuxlYq8OYS3uia/rjyiTwi9O7TaO+/WyDyiDA49E7erO3fF9bj6I7k7qHi9O3SoKbyBSfc7drSSvGPvCT2pQay7t2huPGnC7byUCQY8CEaBu6rHoDhx8hE8/fgQvCjLl7zdeHS8x/3TO0Isc7tas3y8jwQLvUKhhDz+foU8fCDCPC+ZgTywD5Y7ZR8tOla66rtCCLm8gWg3vDoKrLxbWDE76SefPBkj2zrlqJi7pebfuv6Df7zWQ9a7lHA6PGDXtzzMv1Q8mtxpOwJ4lzxKGZ28mGnMPDw6z7yxY/O7m2Leu7juYjwvVge8zFigPGpIYjtWumo5xs2wOgyCjbxrZ6K8bbaFvKzTCbsks8W7C7mePIU9DzxQyEY8posUvAW0ozrHlh88CyBTPJRwursxySQ757SBuqcRCbwNCIK8EL6ZvIG+iLsIRgE8rF74vOJZtbuUcDq8r/DVPMpMt7sL3Vi8eWqquww/kzqj2vY5auGtu85kiTwMPxM66KGqvBIxNzuwUpA8v2b7u09C0rx7ms08NUirvFYQPLxKPdc68mimvP5fRTtoPPm7XuqOOgOJ+jxfLYm7u58AvXz8B72PR4W6ldfuuys+tbvYKwW7pkiaPLB2SjvKj7G875POvA6yML7qFEg9Eu68O6Up2rz77Kc84CmSPP6ivzz4sJu6/C+iOaUpWjwq14A84E3MOYB7Dr2d1Xu775NOvC6e+7spUYw8PzPhO5TGizt29ww9yNkZPY7lyrz020M7QRsQu3z8BzwkCZe79YXyO8jZmTzvGUM8HgQcO9kYrrzxBmy8hLeaPLYBOjz+oj88flBlO6GqUzuiMMi8fxlUvCr7ujz41NU8DA38PBeMAzx7uY28TTZpvFG1bzxtc4s89ucsPEereTwfipC82p4iPKtNFbzo5KQ7pcKlOW5gtDzO73c7B6FMOzRbgjxCXoo8v0JBOSl1RrwxDJ+7XWSaPD3Aw7sOsjA8tuJ5vKw6Pry5k5c8ZUNnvG/H6DyVTAA8Shkdvd7+aDvtpiW9qUGsPFTgmDwbcr68TTbpO1DnhryNX9a7mrivvIqpPjxsqhy81HrnOzv31Dvth+U6UtQvPBz4MrvtpqW84OYXvRz4sjxwkFe8zSGPuycCqbyFPY8818nKOw84JTy8bWk8USqBvBGHiLtosQo8BOs0u9skl7xQ54Y8uvrLPOknn7w705o8Jny0PAd9EjxhoKa8Iv2tu2M3/jtsVEs8DcUHPQSEADs3eE48GkKbupRR+rvdeHQ7Xy2JvO1jKz0xMFm8sWPzux07LbyrTZW7bdq/O6Pa9r0ahRW9CyDTOjSjdjyQ8bO8yaIIPfupLTz/CfQ7xndfvJs+JD0zPEK8KO/RvMpw8bwObzY7fm+lPJtiXrz5BHm8WmsIvKlBrLuDdKA7hWHJOgd9Ers0o/Y7nlvwu5NAl7u8BrW6utYRO2SZuDxyNYw8CppevAY6GDxVqQe9oGdZPFa6ary3RLS70NcmO2PQSb36ZrM86q2TPML42LwewaE8k2RRPDmocTsi/S29o/k2PHRlr7zjnC+8gHsOPUpcFzxtl8W6tuL5vHw/gry/2wy9yaIIvINV4Dx3fQG7ISFoPO7pnzwGXlK8HPiyPGAaMjzBC7A7MQyfu+eC6jyV1+67pDyxvBWkVLxrJKg754LqOScCKbwpUQy8KIgdOJDSc7zDfk08tLLWvNZDVjyh7c28ShmdvMnlgjs2NdS8ISHovP5+hbxGIIs8ayQouyKnXDzBcmS6zw44u86IQ7yl5l+7cngGvWvOVrsEhIC7yNkZPJODkbuAn0g8XN6lPOaVwbuTgxG8OR2DPAb3HTzlqJi8nUoNvCAVf73Mmxo9afSEu4FotzveHSk8c0ZvOMFOqjwP9Sq87iwavIEBg7xIUK68IbozuozZ4btg17c7vx4Hvarr2rtp9IQ8Rt0QO+1jqzyeNzY8kNLzO8sVpry98108OCL9uyisV7vhr4Y8FgaPvLFjczw42og8gWg3vPX6gzsNk/C83GeRPCUVgDy0jpw7yNkZu2VD5zvh93o81h+cuw3Fhzyl5t+86Y7TvHa0EjyzCCi7WmsIPIy1Jzy00Ra6NUiru50rTTx50d47/HKcO2wwETw0f7y8sFIQvNxnkbzS4w855pVBu9FdGzx9yvC6TM80vFQjkzy/Zvs7BhtYPLjKKLqPa787A/6LOyiInbzooSq8728UPIFJ97wq+7q8R6v5u1tYMbwdomG6iSPKPAb3HTx3oTu7fGO8POqtk7ze/ug84wNkPMnq/DsB8iK9ogwOu6lBrDznguo8NQUxvHKcwDo28tm7yNmZPN1UurxCoYS80m7+Oy+9OzzGzTC836MdvCDNCrtaawi7dVLYPEfKuTxzRm88cCmjOyXSBbwGOpi879ZIO8dTJbtqnrO8NMI2vR1+J7xwTV087umfPFG17zsC30s8oYaZPKllZrzZGK47zss9vP21FryZywa9bbYFPVNapDt2G0e7E3SxPMUjgry5dNc895Hbu0H8z7ueN7a7OccxPFhfH7vC1B48n3owvEhQLrzu6Z+8HTutvEBSITw6Taa5g1XgPCzEqbxfLYk9OYQ3vBlm1bvPUTI8wIU7PIy1pzyFyP07gzGmO3NGb7yS3ty7O5CguyEhaLyWoF28pmxUOaZImrz+g/87mnU1vFbsgTxvo668PFmPO2KNTzy09VC8LG5YPHhL6rsvJPC7kTQuvEGCxDlhB9s6u58AvfCAd7z0t4k7kVjoOCkOkrxMjDq8iPOmPL0SnrxsMJG7OEG9vCUa+rvx4rE7cpxAPDCGqjukf6u8TEnAvNn57TweBBw7JdKFvIy1p7vIg8i7' + + embeddings = "VEfNvMLUnrwFleO8hcj9vEE/yrzyjOA84E1MvNfoCrxjrI+8sZUKvNgrBT17uY07gJ/IvNvhHLrUemc8KXXGumalIT3YKwU7ZsnbPMhATrwTt6u8JEwRPNMmCjxGREW7TRKvu6/MG7zAyDU8wXLkuuMDZDsXsL28zHzaOw0IArzOiMO8LtASvPKM4Dul5l+80V0bPGVDZ7wYNrI89ucsvJZdYztzRm+8P8ysOyGbc7zrdgK9sdiEPKQ8sbulKdq7KIgdvKIMDj25dNc8k0AXPBn/oLzrdgK8IXe5uz0Dvrt50V68tTjLO4ZOcjoG9x29oGfZufiwmzwMDXy8EL6ZPHvdx7nKjzE8+LCbPG22hTs3EZq7TM+0POrRzTxVZo084wPkO8Nak7z8cpw8pDwxvA2T8LvBC7C72fltvC8Atjp3fYE8JHDLvEYgC7xAdls8YiabPPkEeTzPUbK8gOLCPEBSIbyt5Oy8CpreusNakzywUhA824vLPHRlr7zAhTs7IZtzvHd9AT2xY/O6ok8IvOihqrql5l88K4EvuknWorvYKwW9iXkbvGMTRLw5qPG7onPCPLgNIzwAbK67ftbZPMxYILvAyDW9TLB0vIid1buzCKi7u+d0u8iDSLxNVam8PZyJPNxnETvVANw8Oi5mu9nVszzl65I7DIKNvLGVirxsMJE7tPXQu2PvCT1zRm87p1l9uyRMkbsdfqe8U52ePHRlr7wt9Mw8/C8ivTu02rwJFGq8tpoFPWnC7blWumq7sfy+vG1zCzy9Nlg8iv+PuvxT3DuLU228kVhoOkmTqDrv1kg8ocmTu1WpBzsKml48DzglvI8ECzxwTd27I+pWvIWkQ7xUR007GqlPPBFEDrzGECu865q8PI7BkDwNxYc8tgG6ullMSLsIajs84lk1PNLjD70mv648ZmInO2tnIjzvb5Q8o5KCPLo9xrwKMyq9QqGEvI8ECzxO2508ATUdPRAlTry5kxc8KVGMPJyBHjxIUC476KGqvIU9DzwX87c88PUIParrWrzdlzS/G3K+uzEw2TxB2BU86AhfPAMiRj2dK808a85WPPCft7xU4Bg95Q9NPDxZjzwrpek7yNkZvHa0EjyQ0nM6Nq9fuyjvUbsRq8I7CAMHO3VSWLyuauE7U1qkvPkEeTxs7ZY7B6FMO48Eizy75/S7ieBPvB07rTxmyVu8onPCO5rc6Tu7XIa7oEMfPYngT7u24vk7/+W5PE8eGDxJ1iI9t4cuvBGHiLyH1GY7jfghu+oUSDwa7Mk7iXmbuut2grrq8I2563v8uyofdTxRTrs44lm1vMeWnzukf6s7r4khvEKhhDyhyZO8G5Z4Oy56wTz4sBs81Zknuz3fg7wnJuO74n1vvASEADu98128gUl3vBtyvrtZCU47yep8u5FYaDx2G0e8a85WO5cmUjz3kds8qgqbPCUaerx50d67WKIZPI7BkDua3Om74vKAvL3zXbzXpRA9CI51vLo9xryKzXg7tXtFO9RWLTwnJuM854LqPEIs8zuO5cq8d8V1u9P0cjrQ++C8cGwdPDdUlLoOGeW8auEtu8Z337nlzFK8aRg/vFCkDD0nRSM879bIvKUFID1iStU8EL6ZvLufgLtKgNE7KVEMvJOnSzwahRU895HbvJiIjLvc8n88bmC0PPLP2rywM9C7jTscOoS3mjy/Znu7dhvHuu5Q1Dyq61o6CI71u09hkry0jhw8gb6IPI8EC7uoVAM8gs9rvGM3fjx2G8e81FYtu/ojubyYRRK72Riuu83elDtNNmk70/TyuzUFsbvgKZI7onNCvAehzLumr8679R6+urr6SztX2So8Bl5SOwSEgLv5NpA8LwC2PGPvibzJ6vw7H2tQvOtXwrzXpRC8j0z/uxwcbTy2vr+8VWYNu+t2ArwKmt68NKN2O3XrIzw9A747UU47vaavzjwU+qW8YBqyvE02aTyEt5o8cCmjOxtyPrxs7ZY775NOu+SJWLxMJQY8/bWWu6IMDrzSSsQ7GSPbPLlQnbpVzcE7Pka4PJ96sLycxJg8v/9GPO2HZTyeW3C8Vpawtx2iYTwWBg87/qI/OviwGzxyWcY7M9WNPIA4FD32C2e8tNGWPJ43trxCoYS8FGHavItTbbu7n4C80NemPLm30Ty1OMu7vG1pvG3aPztBP0o75Q/NPJhFEj2V9i683PL/O97+aLz6iu27cdPRum/mKLwvVgc89fqDu3LA+jvm2Ls8mVZ1PIuFBD3ZGK47Cpreut7+aLziWTU8XSEgPMvSKzzO73e5040+vBlmVTxS1K+8mQ4BPZZ8o7w8FpW6OR0DPSSPCz21Vwu99fqDOjMYiDy7XAY8oYaZO+aVwTyX49c84OaXOqdZfTunEQk7B8AMvMDs7zo/D6e8OP5CvN9gIzwNCII8FefOPE026TpzIjU8XsvOO+J9b7rkIiQ8is34O+e0AbxBpv67hcj9uiPq1jtCoQQ8JfY/u86nAz0Wkf28LnrBPJlW9Tt8P4K7BbSjO9grhbyAOJS8G3K+vJLe3LzXpZA7NQUxPJs+JDz6vAS8QHZbvYNVYDrj3yk88PWIPOJ97zuSIVc8ZUPnPMqPsbx2cZi7QfzPOxYGDz2hqtO6H2tQO543NjyFPY+7JRUAOt0wgDyJeZu8MpKTu6AApTtg1ze82JI5vKllZjvrV0I7HX6nu7vndDxg1ze8jwQLu1ZTNjuJvBU7BXGpvAP+C7xJk6g8j2u/vBABlLzlqBi8M9WNutRWLTx0zGM9sHbKPLoZDDtmyVu8tpqFOvPumjyuRqe87lBUvFU0drxs7Za8ejMZOzJPGbyC7qu863v8PDPVjTxJ1iI7Ca01PLuAQLuNHFy7At9LOwP+i7tYxlO80NemO9elkDx45LU8h9TmuzxZjzz/5bk8p84OurvndLwAkGi7XL9luCSzRTwMgg08vrxMPKIwyDwdomG8K6VpPGPvCTxkmTi7M/lHPGxUSzxwKSM8wQuwvOqtkzrLFSa8SbdivAMixjw2r9+7xWt2vAyCDT1NEi87B8CMvG1zi7xpwm27MrbNO9R6Z7xJt+K7jNnhu9ZiFrve/ug55CKkvCwHJLqsOr47+ortvPwvIr2v8NW8YmmVOE+FTLywUhA8MTBZvMiDyLtx8hG8OEE9vMDsbzroCF88DelBOobnPbx+b6U8sbnEOywr3ro93wO9dMzjup2xwbwnRaO7cRZMu8Z337vS44+7VpYwvFWphzxKgNE8L1aHPLPFLbunzo66zFggPN+jHbs7tFo8nW7HO9JKRLyoeD28Fm1DPGZip7u5dNe7KMsXvFnlkzxQpAw7MrZNPHpX0zwSyoK7ayQovPR0Dz3gClK8/juLPDjaCLvqrZO7a4vcO9HEzzvife88KKzXvDmocbwpMkw7t2huvaIMjjznguo7Gy/EOzxZjzoLuZ48qi5VvCjLFzuDmNo654LquyrXgDy7XAa8e7mNvJ7QAb0Rq8K7ojBIvBN0MTuOfha8GoUVveb89bxMsHS8jV9WPPKM4LyAOJS8me9AvZv7qbsbcr47tuL5uaXmXzweKNa7rkYnPINV4Lxcv+W8tVcLvI8oxbzvbxS7oYaZu9+jHT0cHO08c7uAPCSzRTywUhA85xu2u+wBcTuJvJU8PBYVusTghzsnAim8acJtPFQE0zzFIwI9C7meO1DIRry7XAY8MKpkPJZd47suN0e5JTm6u6BDn7zfx1e8AJDoOr9CQbwaQps7x/1TPLTRFryqLtU8JybjPIXI/Tz6I7k6mVb1PMWKNryd1fs8Ok0mPHt2kzy9Ep48TTZpvPS3ibwGOpi8Ns4fPBqFlbr3Kqc8+QR5vHLA+rt7uY289YXyPI6iULxL4gu8Tv/XuycCKbwCnFG8C7kevVG1b7zIXw68GoWVO4rNeDnrM4i8MxgIPUNLs7zSoJW86ScfO+rRzbs6Cqw8NxGautP0cjw0wjY8CGq7vAkU6rxKgNG5+uA+vJXXbrwKM6o86vCNOu+yjjoQAZS8xATCOQVxKbynzo68wxcZvMhATjzS4488ArsRvNEaobwRh4i7t4euvAvd2DwnAik8UtQvvBFEDrz4sJs79gtnvOknnzy+vEy8D3sfPLH8vjzmLo28KVGMvOtXwjvpapm8HBxtPH3K8Lu753Q8/l9FvLvn9DomoG48fET8u9zy/7wMpke8zmQJu3oU2TzlD828KteAPAwNfLu+mBI5ldduPNZDVjq+vEy8eEvqvDHJpLwUPaC6qi7VPABsLjwFcSm72sJcu+bYO7v41NW8RiALvYB7DjzL0is7qLs3us1FSbzaf2K8MnNTuxABFDzF8Wo838fXvOBNzDzre3w8afQEvQE1nbulBaC78zEVvG5B9LzH/VM82Riuuwu5nrwsByQ8Y6yPvHXro7yQ0nM8nStNPJkyOzwnJmM80m7+O1VmjTzqrZM8dhvHOyAQBbz3baG8KTJMPOlqmbxsVEs8Pq3suy56QbzUVq08X3CDvAE1nTwUHuA7hue9vF8tCbvwOAO6F7A9ugd9kryqLtW7auEtu9ONPryPa7+8o9r2O570OzyFpEO8ntCBPOqtk7sykhO7lC1AOw2TcLswhiq6vx4HvP5fRbwuesG7Mk8ZvA4Z5TlfcAM9DrIwPL//xrzMm5q8JEwRPHBsnbxL4gu8jyjFu99gozrkZZ483GeRPLuAwDuYiIw8iv8PvK5Gpzx+b6W87Yflu3NGbzyE+hQ8a4tcPItT7bsoy5e8L1YHvWQyBDwrga86kPEzvBQ9oDxtl0W8lwKYvGpIYrxQ5wY8AJDovOLyALyw3f489JjJvMdTpTkKMyo8V9mqvH3K8LpyNYy8JHDLOixu2LpQ54Y8Q0uzu8LUnrs0wrY84vIAveihqjwfihA8DIKNvLDd/jywM1C7FB7gOxsLirxAUqE7sulnvH3K8DkAkGg8jsGQvO+TzrynWf287CCxvK4Drbwg8UQ8JRr6vFEqAbskjwu76q2TPNP0cjopDhK8dVJYvFIXKrxLn5G8AK8oPAb3HbxbOXE8Bvedun5Q5ThHyjk8QdiVvBXDlLw0o/Y7aLGKupkOgTxKPdc81kNWPtUAXLxUR827X1FDPf47izxsEVE8akhiPIhaWzxYX5+7hT0PPSrXgLxQC0E8i4WEvKUp2jtCLHM8DcWHO768zLxnK5a89R6+vH9czrorpem73h0pvAnwr7yKzXi8gDgUPf47Czq9zyO8728UOf34EDy6PUY76OSkvKZIGr2ZDgE8gzEmPG3av7v77Ce7/oP/O3MiNTtas/w8x1OlO/D1CDvDfs27ll1jO2Ufrbv1hXK8WINZuxN0sbuxlYq8OYS3uia/rjyiTwi9O7TaO+/WyDyiDA49E7erO3fF9bj6I7k7qHi9O3SoKbyBSfc7drSSvGPvCT2pQay7t2huPGnC7byUCQY8CEaBu6rHoDhx8hE8/fgQvCjLl7zdeHS8x/3TO0Isc7tas3y8jwQLvUKhhDz+foU8fCDCPC+ZgTywD5Y7ZR8tOla66rtCCLm8gWg3vDoKrLxbWDE76SefPBkj2zrlqJi7pebfuv6Df7zWQ9a7lHA6PGDXtzzMv1Q8mtxpOwJ4lzxKGZ28mGnMPDw6z7yxY/O7m2Leu7juYjwvVge8zFigPGpIYjtWumo5xs2wOgyCjbxrZ6K8bbaFvKzTCbsks8W7C7mePIU9DzxQyEY8posUvAW0ozrHlh88CyBTPJRwursxySQ757SBuqcRCbwNCIK8EL6ZvIG+iLsIRgE8rF74vOJZtbuUcDq8r/DVPMpMt7sL3Vi8eWqquww/kzqj2vY5auGtu85kiTwMPxM66KGqvBIxNzuwUpA8v2b7u09C0rx7ms08NUirvFYQPLxKPdc68mimvP5fRTtoPPm7XuqOOgOJ+jxfLYm7u58AvXz8B72PR4W6ldfuuys+tbvYKwW7pkiaPLB2SjvKj7G875POvA6yML7qFEg9Eu68O6Up2rz77Kc84CmSPP6ivzz4sJu6/C+iOaUpWjwq14A84E3MOYB7Dr2d1Xu775NOvC6e+7spUYw8PzPhO5TGizt29ww9yNkZPY7lyrz020M7QRsQu3z8BzwkCZe79YXyO8jZmTzvGUM8HgQcO9kYrrzxBmy8hLeaPLYBOjz+oj88flBlO6GqUzuiMMi8fxlUvCr7ujz41NU8DA38PBeMAzx7uY28TTZpvFG1bzxtc4s89ucsPEereTwfipC82p4iPKtNFbzo5KQ7pcKlOW5gtDzO73c7B6FMOzRbgjxCXoo8v0JBOSl1RrwxDJ+7XWSaPD3Aw7sOsjA8tuJ5vKw6Pry5k5c8ZUNnvG/H6DyVTAA8Shkdvd7+aDvtpiW9qUGsPFTgmDwbcr68TTbpO1DnhryNX9a7mrivvIqpPjxsqhy81HrnOzv31Dvth+U6UtQvPBz4MrvtpqW84OYXvRz4sjxwkFe8zSGPuycCqbyFPY8818nKOw84JTy8bWk8USqBvBGHiLtosQo8BOs0u9skl7xQ54Y8uvrLPOknn7w705o8Jny0PAd9EjxhoKa8Iv2tu2M3/jtsVEs8DcUHPQSEADs3eE48GkKbupRR+rvdeHQ7Xy2JvO1jKz0xMFm8sWPzux07LbyrTZW7bdq/O6Pa9r0ahRW9CyDTOjSjdjyQ8bO8yaIIPfupLTz/CfQ7xndfvJs+JD0zPEK8KO/RvMpw8bwObzY7fm+lPJtiXrz5BHm8WmsIvKlBrLuDdKA7hWHJOgd9Ers0o/Y7nlvwu5NAl7u8BrW6utYRO2SZuDxyNYw8CppevAY6GDxVqQe9oGdZPFa6ary3RLS70NcmO2PQSb36ZrM86q2TPML42LwewaE8k2RRPDmocTsi/S29o/k2PHRlr7zjnC+8gHsOPUpcFzxtl8W6tuL5vHw/gry/2wy9yaIIvINV4Dx3fQG7ISFoPO7pnzwGXlK8HPiyPGAaMjzBC7A7MQyfu+eC6jyV1+67pDyxvBWkVLxrJKg754LqOScCKbwpUQy8KIgdOJDSc7zDfk08tLLWvNZDVjyh7c28ShmdvMnlgjs2NdS8ISHovP5+hbxGIIs8ayQouyKnXDzBcmS6zw44u86IQ7yl5l+7cngGvWvOVrsEhIC7yNkZPJODkbuAn0g8XN6lPOaVwbuTgxG8OR2DPAb3HTzlqJi8nUoNvCAVf73Mmxo9afSEu4FotzveHSk8c0ZvOMFOqjwP9Sq87iwavIEBg7xIUK68IbozuozZ4btg17c7vx4Hvarr2rtp9IQ8Rt0QO+1jqzyeNzY8kNLzO8sVpry98108OCL9uyisV7vhr4Y8FgaPvLFjczw42og8gWg3vPX6gzsNk/C83GeRPCUVgDy0jpw7yNkZu2VD5zvh93o81h+cuw3Fhzyl5t+86Y7TvHa0EjyzCCi7WmsIPIy1Jzy00Ra6NUiru50rTTx50d47/HKcO2wwETw0f7y8sFIQvNxnkbzS4w855pVBu9FdGzx9yvC6TM80vFQjkzy/Zvs7BhtYPLjKKLqPa787A/6LOyiInbzooSq8728UPIFJ97wq+7q8R6v5u1tYMbwdomG6iSPKPAb3HTx3oTu7fGO8POqtk7ze/ug84wNkPMnq/DsB8iK9ogwOu6lBrDznguo8NQUxvHKcwDo28tm7yNmZPN1UurxCoYS80m7+Oy+9OzzGzTC836MdvCDNCrtaawi7dVLYPEfKuTxzRm88cCmjOyXSBbwGOpi879ZIO8dTJbtqnrO8NMI2vR1+J7xwTV087umfPFG17zsC30s8oYaZPKllZrzZGK47zss9vP21FryZywa9bbYFPVNapDt2G0e7E3SxPMUjgry5dNc895Hbu0H8z7ueN7a7OccxPFhfH7vC1B48n3owvEhQLrzu6Z+8HTutvEBSITw6Taa5g1XgPCzEqbxfLYk9OYQ3vBlm1bvPUTI8wIU7PIy1pzyFyP07gzGmO3NGb7yS3ty7O5CguyEhaLyWoF28pmxUOaZImrz+g/87mnU1vFbsgTxvo668PFmPO2KNTzy09VC8LG5YPHhL6rsvJPC7kTQuvEGCxDlhB9s6u58AvfCAd7z0t4k7kVjoOCkOkrxMjDq8iPOmPL0SnrxsMJG7OEG9vCUa+rvx4rE7cpxAPDCGqjukf6u8TEnAvNn57TweBBw7JdKFvIy1p7vIg8i7" data = [] for i, text in enumerate(input): - obj = Embedding( - embedding=[], - index=i, - object='embedding' - ) + obj = Embedding(embedding=[], index=i, object="embedding") obj.embedding = embeddings data.append(obj) @@ -61,10 +52,7 @@ class MockEmbeddingsClass: return CreateEmbeddingResponse( data=data, model=model, - object='list', + object="list", # marked: usage of embeddings should equal the number of testcase - usage=Usage( - prompt_tokens=2, - total_tokens=2 - ) - ) \ No newline at end of file + usage=Usage(prompt_tokens=2, total_tokens=2), + ) diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py index 9466f4bfb8..270a88e85f 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py @@ -10,58 +10,92 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError class MockModerationClass: - def moderation_create(self: Moderations,*, + def moderation_create( + self: Moderations, + *, input: Union[str, list[str]], model: Union[str, Literal["text-moderation-latest", "text-moderation-stable"]] | NotGiven = NOT_GIVEN, - **kwargs: Any + **kwargs: Any, ) -> ModerationCreateResponse: if isinstance(input, str): input = [input] - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()): - raise InvokeAuthorizationError('Invalid base url') - + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + raise InvokeAuthorizationError("Invalid base url") + if len(self._client.api_key) < 18: - raise InvokeAuthorizationError('Invalid API key') + raise InvokeAuthorizationError("Invalid API key") for text in input: result = [] - if 'kill' in text: + if "kill" in text: moderation_categories = { - 'harassment': False, 'harassment/threatening': False, 'hate': False, 'hate/threatening': False, - 'self-harm': False, 'self-harm/instructions': False, 'self-harm/intent': False, 'sexual': False, - 'sexual/minors': False, 'violence': False, 'violence/graphic': False + "harassment": False, + "harassment/threatening": False, + "hate": False, + "hate/threatening": False, + "self-harm": False, + "self-harm/instructions": False, + "self-harm/intent": False, + "sexual": False, + "sexual/minors": False, + "violence": False, + "violence/graphic": False, } moderation_categories_scores = { - 'harassment': 1.0, 'harassment/threatening': 1.0, 'hate': 1.0, 'hate/threatening': 1.0, - 'self-harm': 1.0, 'self-harm/instructions': 1.0, 'self-harm/intent': 1.0, 'sexual': 1.0, - 'sexual/minors': 1.0, 'violence': 1.0, 'violence/graphic': 1.0 + "harassment": 1.0, + "harassment/threatening": 1.0, + "hate": 1.0, + "hate/threatening": 1.0, + "self-harm": 1.0, + "self-harm/instructions": 1.0, + "self-harm/intent": 1.0, + "sexual": 1.0, + "sexual/minors": 1.0, + "violence": 1.0, + "violence/graphic": 1.0, } - result.append(Moderation( - flagged=True, - categories=Categories(**moderation_categories), - category_scores=CategoryScores(**moderation_categories_scores) - )) + result.append( + Moderation( + flagged=True, + categories=Categories(**moderation_categories), + category_scores=CategoryScores(**moderation_categories_scores), + ) + ) else: moderation_categories = { - 'harassment': False, 'harassment/threatening': False, 'hate': False, 'hate/threatening': False, - 'self-harm': False, 'self-harm/instructions': False, 'self-harm/intent': False, 'sexual': False, - 'sexual/minors': False, 'violence': False, 'violence/graphic': False + "harassment": False, + "harassment/threatening": False, + "hate": False, + "hate/threatening": False, + "self-harm": False, + "self-harm/instructions": False, + "self-harm/intent": False, + "sexual": False, + "sexual/minors": False, + "violence": False, + "violence/graphic": False, } moderation_categories_scores = { - 'harassment': 0.0, 'harassment/threatening': 0.0, 'hate': 0.0, 'hate/threatening': 0.0, - 'self-harm': 0.0, 'self-harm/instructions': 0.0, 'self-harm/intent': 0.0, 'sexual': 0.0, - 'sexual/minors': 0.0, 'violence': 0.0, 'violence/graphic': 0.0 + "harassment": 0.0, + "harassment/threatening": 0.0, + "hate": 0.0, + "hate/threatening": 0.0, + "self-harm": 0.0, + "self-harm/instructions": 0.0, + "self-harm/intent": 0.0, + "sexual": 0.0, + "sexual/minors": 0.0, + "violence": 0.0, + "violence/graphic": 0.0, } - result.append(Moderation( - flagged=False, - categories=Categories(**moderation_categories), - category_scores=CategoryScores(**moderation_categories_scores) - )) + result.append( + Moderation( + flagged=False, + categories=Categories(**moderation_categories), + category_scores=CategoryScores(**moderation_categories_scores), + ) + ) - return ModerationCreateResponse( - id='shiroii kuloko', - model=model, - results=result - ) \ No newline at end of file + return ModerationCreateResponse(id="shiroii kuloko", model=model, results=result) diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_remote.py b/api/tests/integration_tests/model_runtime/__mock/openai_remote.py index 0124ac045b..cb8f249543 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_remote.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_remote.py @@ -6,17 +6,18 @@ from openai.types.model import Model class MockModelClass: """ - mock class for openai.models.Models + mock class for openai.models.Models """ + def list( self, **kwargs, ) -> list[Model]: return [ Model( - id='ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ', + id="ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ", created=int(time()), - object='model', - owned_by='organization:org-123', + object="model", + owned_by="organization:org-123", ) - ] \ No newline at end of file + ] diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py index 755fec4c1f..ef361e8613 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py @@ -9,7 +9,8 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError class MockSpeech2TextClass: - def speech2text_create(self: Transcriptions, + def speech2text_create( + self: Transcriptions, *, file: FileTypes, model: Union[str, Literal["whisper-1"]], @@ -17,14 +18,12 @@ class MockSpeech2TextClass: prompt: str | NotGiven = NOT_GIVEN, response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] | NotGiven = NOT_GIVEN, temperature: float | NotGiven = NOT_GIVEN, - **kwargs: Any + **kwargs: Any, ) -> Transcription: - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()): - raise InvokeAuthorizationError('Invalid base url') - + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + raise InvokeAuthorizationError("Invalid base url") + if len(self._client.api_key) < 18: - raise InvokeAuthorizationError('Invalid API key') - - return Transcription( - text='1, 2, 3, 4, 5, 6, 7, 8, 9, 10' - ) \ No newline at end of file + raise InvokeAuthorizationError("Invalid API key") + + return Transcription(text="1, 2, 3, 4, 5, 6, 7, 8, 9, 10") diff --git a/api/tests/integration_tests/model_runtime/__mock/xinference.py b/api/tests/integration_tests/model_runtime/__mock/xinference.py index 7cb0a1318e..777737187e 100644 --- a/api/tests/integration_tests/model_runtime/__mock/xinference.py +++ b/api/tests/integration_tests/model_runtime/__mock/xinference.py @@ -19,40 +19,43 @@ from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage class MockXinferenceClass: - def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]: - if not re.match(r'https?:\/\/[^\s\/$.?#].[^\s]*$', self.base_url): - raise RuntimeError('404 Not Found') - - if 'generate' == model_uid: + def get_chat_model( + self: Client, model_uid: str + ) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]: + if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url): + raise RuntimeError("404 Not Found") + + if "generate" == model_uid: return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={}) - if 'chat' == model_uid: + if "chat" == model_uid: return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={}) - if 'embedding' == model_uid: + if "embedding" == model_uid: return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={}) - if 'rerank' == model_uid: + if "rerank" == model_uid: return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={}) - raise RuntimeError('404 Not Found') - + raise RuntimeError("404 Not Found") + def get(self: Session, url: str, **kwargs): response = Response() - if 'v1/models/' in url: + if "v1/models/" in url: # get model uid - model_uid = url.split('/')[-1] or '' - if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \ - model_uid not in ['generate', 'chat', 'embedding', 'rerank']: + model_uid = url.split("/")[-1] or "" + if not re.match( + r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", model_uid + ) and model_uid not in ["generate", "chat", "embedding", "rerank"]: response.status_code = 404 - response._content = b'{}' + response._content = b"{}" return response # check if url is valid - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", url): response.status_code = 404 - response._content = b'{}' + response._content = b"{}" return response - - if model_uid in ['generate', 'chat']: + + if model_uid in ["generate", "chat"]: response.status_code = 200 - response._content = b'''{ + response._content = b"""{ "model_type": "LLM", "address": "127.0.0.1:43877", "accelerators": [ @@ -75,12 +78,12 @@ class MockXinferenceClass: "revision": null, "context_length": 2048, "replica": 1 - }''' + }""" return response - - elif model_uid == 'embedding': + + elif model_uid == "embedding": response.status_code = 200 - response._content = b'''{ + response._content = b"""{ "model_type": "embedding", "address": "127.0.0.1:43877", "accelerators": [ @@ -93,51 +96,48 @@ class MockXinferenceClass: ], "revision": null, "max_tokens": 512 - }''' + }""" return response - - elif 'v1/cluster/auth' in url: + + elif "v1/cluster/auth" in url: response.status_code = 200 - response._content = b'''{ + response._content = b"""{ "auth": true - }''' + }""" return response - + def _check_cluster_authenticated(self): self._cluster_authed = True - - def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool) -> dict: + + def rerank( + self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool + ) -> dict: # check if self._model_uid is a valid uuid - if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \ - self._model_uid != 'rerank': - raise RuntimeError('404 Not Found') - - if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._base_url): - raise RuntimeError('404 Not Found') + if ( + not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid) + and self._model_uid != "rerank" + ): + raise RuntimeError("404 Not Found") + + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._base_url): + raise RuntimeError("404 Not Found") if top_n is None: top_n = 1 return { - 'results': [ - { - 'index': i, - 'document': doc, - 'relevance_score': 0.9 - } - for i, doc in enumerate(documents[:top_n]) + "results": [ + {"index": i, "document": doc, "relevance_score": 0.9} for i, doc in enumerate(documents[:top_n]) ] } - - def create_embedding( - self: RESTfulGenerateModelHandle, - input: Union[str, list[str]], - **kwargs - ) -> dict: + + def create_embedding(self: RESTfulGenerateModelHandle, input: Union[str, list[str]], **kwargs) -> dict: # check if self._model_uid is a valid uuid - if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \ - self._model_uid != 'embedding': - raise RuntimeError('404 Not Found') + if ( + not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid) + and self._model_uid != "embedding" + ): + raise RuntimeError("404 Not Found") if isinstance(input, str): input = [input] @@ -147,32 +147,27 @@ class MockXinferenceClass: object="list", model=self._model_uid, data=[ - EmbeddingData( - index=i, - object="embedding", - embedding=[1919.810 for _ in range(768)] - ) + EmbeddingData(index=i, object="embedding", embedding=[1919.810 for _ in range(768)]) for i in range(ipt_len) ], - usage=EmbeddingUsage( - prompt_tokens=ipt_len, - total_tokens=ipt_len - ) + usage=EmbeddingUsage(prompt_tokens=ipt_len, total_tokens=ipt_len), ) return embedding -MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + @pytest.fixture def setup_xinference_mock(request, monkeypatch: MonkeyPatch): if MOCK: - monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model) - monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated) - monkeypatch.setattr(Session, 'get', MockXinferenceClass.get) - monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding) - monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank) + monkeypatch.setattr(Client, "get_model", MockXinferenceClass.get_chat_model) + monkeypatch.setattr(Client, "_check_cluster_authenticated", MockXinferenceClass._check_cluster_authenticated) + monkeypatch.setattr(Session, "get", MockXinferenceClass.get) + monkeypatch.setattr(RESTfulEmbeddingModelHandle, "create_embedding", MockXinferenceClass.create_embedding) + monkeypatch.setattr(RESTfulRerankModelHandle, "rerank", MockXinferenceClass.rerank) yield if MOCK: - monkeypatch.undo() \ No newline at end of file + monkeypatch.undo() diff --git a/api/tests/integration_tests/model_runtime/anthropic/test_llm.py b/api/tests/integration_tests/model_runtime/anthropic/test_llm.py index 0d54d97daa..8f7e9ec487 100644 --- a/api/tests/integration_tests/model_runtime/anthropic/test_llm.py +++ b/api/tests/integration_tests/model_runtime/anthropic/test_llm.py @@ -10,79 +10,60 @@ from core.model_runtime.model_providers.anthropic.llm.llm import AnthropicLargeL from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_validate_credentials(setup_anthropic_mock): model = AnthropicLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='claude-instant-1.2', - credentials={ - 'anthropic_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="claude-instant-1.2", credentials={"anthropic_api_key": "invalid_key"}) model.validate_credentials( - model='claude-instant-1.2', - credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - } + model="claude-instant-1.2", credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")} ) -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_invoke_model(setup_anthropic_mock): model = AnthropicLargeLanguageModel() response = model.invoke( - model='claude-instant-1.2', + model="claude-instant-1.2", credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY'), - 'anthropic_api_url': os.environ.get('ANTHROPIC_API_URL') + "anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY"), + "anthropic_api_url": os.environ.get("ANTHROPIC_API_URL"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'top_p': 1.0, - 'max_tokens': 10 - }, - stop=['How'], + model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_invoke_stream_model(setup_anthropic_mock): model = AnthropicLargeLanguageModel() response = model.invoke( - model='claude-instant-1.2', - credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - }, + model="claude-instant-1.2", + credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -98,18 +79,14 @@ def test_get_num_tokens(): model = AnthropicLargeLanguageModel() num_tokens = model.get_num_tokens( - model='claude-instant-1.2', - credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - }, + model="claude-instant-1.2", + credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 18 diff --git a/api/tests/integration_tests/model_runtime/anthropic/test_provider.py b/api/tests/integration_tests/model_runtime/anthropic/test_provider.py index 7eaa40dfdd..6f1e50f431 100644 --- a/api/tests/integration_tests/model_runtime/anthropic/test_provider.py +++ b/api/tests/integration_tests/model_runtime/anthropic/test_provider.py @@ -7,17 +7,11 @@ from core.model_runtime.model_providers.anthropic.anthropic import AnthropicProv from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_validate_provider_credentials(setup_anthropic_mock): provider = AnthropicProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py b/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py index 6afec540ad..8f50ebf7a6 100644 --- a/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py @@ -17,101 +17,90 @@ from core.model_runtime.model_providers.azure_openai.llm.llm import AzureOpenAIL from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_credentials_for_chat_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='gpt35', + model="gpt35", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': 'invalid_key', - 'base_model_name': 'gpt-35-turbo' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": "invalid_key", + "base_model_name": "gpt-35-turbo", + }, ) model.validate_credentials( - model='gpt35', + model="gpt35", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo", + }, ) -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_validate_credentials_for_completion_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='gpt-35-turbo-instruct', + model="gpt-35-turbo-instruct", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': 'invalid_key', - 'base_model_name': 'gpt-35-turbo-instruct' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": "invalid_key", + "base_model_name": "gpt-35-turbo-instruct", + }, ) model.validate_credentials( - model='gpt-35-turbo-instruct', + model="gpt-35-turbo-instruct", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo-instruct' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo-instruct", + }, ) -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_invoke_completion_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt-35-turbo-instruct', + model="gpt-35-turbo-instruct", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo-instruct' - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 1 + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo-instruct", }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 1}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_invoke_stream_completion_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt-35-turbo-instruct', + model="gpt-35-turbo-instruct", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo-instruct' - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo-instruct", }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -122,66 +111,60 @@ def test_invoke_stream_completion_model(setup_openai_mock): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt35', + model="gpt35", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo' + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.0, - 'top_p': 1.0, - 'presence_penalty': 0.0, - 'frequency_penalty': 0.0, - 'max_tokens': 10 + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_stream_chat_model(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt35', + model="gpt35", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo' + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -194,109 +177,87 @@ def test_invoke_stream_chat_model(setup_openai_mock): assert chunk.delta.usage is not None assert chunk.delta.usage.completion_tokens > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model_with_vision(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt-4v', + model="gpt-4v", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-4-vision-preview' + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-4-vision-preview", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content=[ TextPromptMessageContent( - data='Hello World!', + data="Hello World!", ), ImagePromptMessageContent( - data='' - ) + data="" + ), ] - ) + ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model_with_tools(setup_openai_mock): model = AzureOpenAILargeLanguageModel() result = model.invoke( - model='gpt-35-turbo', + model="gpt-35-turbo", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'gpt-35-turbo' + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "gpt-35-turbo", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content="what's the weather today in London?", - ) + ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), PromptMessageTool( - name='get_stock_price', - description='Get the current stock price', + name="get_stock_price", + description="Get the current stock price", parameters={ "type": "object", - "properties": { - "symbol": { - "type": "string", - "description": "The stock symbol" - } - }, - "required": [ - "symbol" - ] - } - ) + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), ], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) @@ -308,32 +269,22 @@ def test_get_num_tokens(): model = AzureOpenAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='gpt-35-turbo-instruct', - credentials={ - 'base_model_name': 'gpt-35-turbo-instruct' - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + model="gpt-35-turbo-instruct", + credentials={"base_model_name": "gpt-35-turbo-instruct"}, + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert num_tokens == 3 num_tokens = model.get_num_tokens( - model='gpt35', - credentials={ - 'base_model_name': 'gpt-35-turbo' - }, + model="gpt35", + credentials={"base_model_name": "gpt-35-turbo"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py index 8b838eb8fc..a1ae2b2e5b 100644 --- a/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py @@ -8,45 +8,43 @@ from core.model_runtime.model_providers.azure_openai.text_embedding.text_embeddi from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_validate_credentials(setup_openai_mock): model = AzureOpenAITextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='embedding', + model="embedding", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': 'invalid_key', - 'base_model_name': 'text-embedding-ada-002' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": "invalid_key", + "base_model_name": "text-embedding-ada-002", + }, ) model.validate_credentials( - model='embedding', + model="embedding", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'text-embedding-ada-002' - } + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "text-embedding-ada-002", + }, ) -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_invoke_model(setup_openai_mock): model = AzureOpenAITextEmbeddingModel() result = model.invoke( - model='embedding', + model="embedding", credentials={ - 'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'), - 'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'), - 'base_model_name': 'text-embedding-ada-002' + "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"), + "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"), + "base_model_name": "text-embedding-ada-002", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -58,14 +56,7 @@ def test_get_num_tokens(): model = AzureOpenAITextEmbeddingModel() num_tokens = model.get_num_tokens( - model='embedding', - credentials={ - 'base_model_name': 'text-embedding-ada-002' - }, - texts=[ - "hello", - "world" - ] + model="embedding", credentials={"base_model_name": "text-embedding-ada-002"}, texts=["hello", "world"] ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_llm.py b/api/tests/integration_tests/model_runtime/baichuan/test_llm.py index 1cae9a6dd0..ad58610287 100644 --- a/api/tests/integration_tests/model_runtime/baichuan/test_llm.py +++ b/api/tests/integration_tests/model_runtime/baichuan/test_llm.py @@ -17,111 +17,99 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) + def test_validate_credentials_for_chat_model(): sleep(3) model = BaichuanLarguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='baichuan2-turbo', - credentials={ - 'api_key': 'invalid_key', - 'secret_key': 'invalid_key' - } + model="baichuan2-turbo", credentials={"api_key": "invalid_key", "secret_key": "invalid_key"} ) model.validate_credentials( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') - } + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), + }, ) + def test_invoke_model(): sleep(3) model = BaichuanLarguageModel() response = model.invoke( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_model_with_system_message(): sleep(3) model = BaichuanLarguageModel() response = model.invoke( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), }, prompt_messages=[ - SystemPromptMessage( - content='请记住你是Kasumi。' - ), - UserPromptMessage( - content='现在告诉我你是谁?' - ) + SystemPromptMessage(content="请记住你是Kasumi。"), + UserPromptMessage(content="现在告诉我你是谁?"), ], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_stream_model(): sleep(3) model = BaichuanLarguageModel() response = model.invoke( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -131,34 +119,31 @@ def test_invoke_stream_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_invoke_with_search(): sleep(3) model = BaichuanLarguageModel() response = model.invoke( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='北京今天的天气怎么样' - ) - ], + prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, - 'with_search_enhance': True, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, + "with_search_enhance": True, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) - total_message = '' + total_message = "" for chunk in response: assert isinstance(chunk, LLMResultChunk) assert isinstance(chunk.delta, LLMResultChunkDelta) @@ -166,25 +151,22 @@ def test_invoke_with_search(): assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True total_message += chunk.delta.message.content - assert '不' not in total_message + assert "不" not in total_message + def test_get_num_tokens(): sleep(3) model = BaichuanLarguageModel() response = model.get_num_tokens( - model='baichuan2-turbo', + model="baichuan2-turbo", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), - 'secret_key': os.environ.get('BAICHUAN_SECRET_KEY') + "api_key": os.environ.get("BAICHUAN_API_KEY"), + "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - tools=[] + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], ) assert isinstance(response, int) - assert response == 9 \ No newline at end of file + assert response == 9 diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_provider.py b/api/tests/integration_tests/model_runtime/baichuan/test_provider.py index 87b3d9a609..4036edfb7a 100644 --- a/api/tests/integration_tests/model_runtime/baichuan/test_provider.py +++ b/api/tests/integration_tests/model_runtime/baichuan/test_provider.py @@ -10,14 +10,6 @@ def test_validate_provider_credentials(): provider = BaichuanProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={ - 'api_key': 'hahahaha' - } - ) + provider.validate_provider_credentials(credentials={"api_key": "hahahaha"}) - provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("BAICHUAN_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py b/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py index 1210ebc53d..cbc63f3978 100644 --- a/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py @@ -11,18 +11,10 @@ def test_validate_credentials(): model = BaichuanTextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='baichuan-text-embedding', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="baichuan-text-embedding", credentials={"api_key": "invalid_key"}) model.validate_credentials( - model='baichuan-text-embedding', - credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY') - } + model="baichuan-text-embedding", credentials={"api_key": os.environ.get("BAICHUAN_API_KEY")} ) @@ -30,44 +22,40 @@ def test_invoke_model(): model = BaichuanTextEmbeddingModel() result = model.invoke( - model='baichuan-text-embedding', + model="baichuan-text-embedding", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), + "api_key": os.environ.get("BAICHUAN_API_KEY"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 assert result.usage.total_tokens == 6 + def test_get_num_tokens(): model = BaichuanTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='baichuan-text-embedding', + model="baichuan-text-embedding", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), + "api_key": os.environ.get("BAICHUAN_API_KEY"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 + def test_max_chunks(): model = BaichuanTextEmbeddingModel() result = model.invoke( - model='baichuan-text-embedding', + model="baichuan-text-embedding", credentials={ - 'api_key': os.environ.get('BAICHUAN_API_KEY'), + "api_key": os.environ.get("BAICHUAN_API_KEY"), }, texts=[ "hello", @@ -92,8 +80,8 @@ def test_max_chunks(): "world", "hello", "world", - ] + ], ) assert isinstance(result, TextEmbeddingResult) - assert len(result.embeddings) == 22 \ No newline at end of file + assert len(result.embeddings) == 22 diff --git a/api/tests/integration_tests/model_runtime/bedrock/test_llm.py b/api/tests/integration_tests/model_runtime/bedrock/test_llm.py index 20dc11151a..c19ec35a6e 100644 --- a/api/tests/integration_tests/model_runtime/bedrock/test_llm.py +++ b/api/tests/integration_tests/model_runtime/bedrock/test_llm.py @@ -13,77 +13,63 @@ def test_validate_credentials(): model = BedrockLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='meta.llama2-13b-chat-v1', - credentials={ - 'anthropic_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="meta.llama2-13b-chat-v1", credentials={"anthropic_api_key": "invalid_key"}) model.validate_credentials( - model='meta.llama2-13b-chat-v1', + model="meta.llama2-13b-chat-v1", credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") - } + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), + }, ) + def test_invoke_model(): model = BedrockLargeLanguageModel() response = model.invoke( - model='meta.llama2-13b-chat-v1', + model="meta.llama2-13b-chat-v1", credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'top_p': 1.0, - 'max_tokens_to_sample': 10 - }, - stop=['How'], + model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens_to_sample": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 + def test_invoke_stream_model(): model = BedrockLargeLanguageModel() response = model.invoke( - model='meta.llama2-13b-chat-v1', + model="meta.llama2-13b-chat-v1", credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens_to_sample': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens_to_sample": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -100,20 +86,18 @@ def test_get_num_tokens(): model = BedrockLargeLanguageModel() num_tokens = model.get_num_tokens( - model='meta.llama2-13b-chat-v1', - credentials = { + model="meta.llama2-13b-chat-v1", + credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), }, messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 18 diff --git a/api/tests/integration_tests/model_runtime/bedrock/test_provider.py b/api/tests/integration_tests/model_runtime/bedrock/test_provider.py index e53d4c1db2..080727829e 100644 --- a/api/tests/integration_tests/model_runtime/bedrock/test_provider.py +++ b/api/tests/integration_tests/model_runtime/bedrock/test_provider.py @@ -10,14 +10,12 @@ def test_validate_provider_credentials(): provider = BedrockProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) provider.validate_provider_credentials( credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), } ) diff --git a/api/tests/integration_tests/model_runtime/chatglm/test_llm.py b/api/tests/integration_tests/model_runtime/chatglm/test_llm.py index e32f01a315..418e88874d 100644 --- a/api/tests/integration_tests/model_runtime/chatglm/test_llm.py +++ b/api/tests/integration_tests/model_runtime/chatglm/test_llm.py @@ -23,79 +23,64 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_credentials_for_chat_model(setup_openai_mock): model = ChatGLMLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='chatglm2-6b', - credentials={ - 'api_base': 'invalid_key' - } - ) + model.validate_credentials(model="chatglm2-6b", credentials={"api_base": "invalid_key"}) - model.validate_credentials( - model='chatglm2-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - } - ) + model.validate_credentials(model="chatglm2-6b", credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_model(setup_openai_mock): model = ChatGLMLargeLanguageModel() response = model.invoke( - model='chatglm2-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, + model="chatglm2-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_stream_model(setup_openai_mock): model = ChatGLMLargeLanguageModel() response = model.invoke( - model='chatglm2-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, + model="chatglm2-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -105,56 +90,45 @@ def test_invoke_stream_model(setup_openai_mock): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_stream_model_with_functions(setup_openai_mock): model = ChatGLMLargeLanguageModel() response = model.invoke( - model='chatglm3-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, + model="chatglm3-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, prompt_messages=[ SystemPromptMessage( - content='你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。' + content="你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。" ), - UserPromptMessage( - content='波士顿天气如何?' - ) + UserPromptMessage(content="波士顿天气如何?"), ], model_parameters={ - 'temperature': 0, - 'top_p': 1.0, + "temperature": 0, + "top_p": 1.0, }, - stop=['you'], - user='abc-123', + stop=["you"], + user="abc-123", stream=True, tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) - ] + ], ) assert isinstance(response, Generator) - + call: LLMResultChunk = None chunks = [] @@ -170,122 +144,87 @@ def test_invoke_stream_model_with_functions(setup_openai_mock): break assert call is not None - assert call.delta.message.tool_calls[0].function.name == 'get_current_weather' + assert call.delta.message.tool_calls[0].function.name == "get_current_weather" -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_model_with_functions(setup_openai_mock): model = ChatGLMLargeLanguageModel() response = model.invoke( - model='chatglm3-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, - prompt_messages=[ - UserPromptMessage( - content='What is the weather like in San Francisco?' - ) - ], + model="chatglm3-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, + prompt_messages=[UserPromptMessage(content="What is the weather like in San Francisco?")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], - user='abc-123', + stop=["you"], + user="abc-123", stream=False, tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) - ] + ], ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 - assert response.message.tool_calls[0].function.name == 'get_current_weather' + assert response.message.tool_calls[0].function.name == "get_current_weather" def test_get_num_tokens(): model = ChatGLMLargeLanguageModel() num_tokens = model.get_num_tokens( - model='chatglm2-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, + model="chatglm2-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) - ] + ], ) assert isinstance(num_tokens, int) assert num_tokens == 77 num_tokens = model.get_num_tokens( - model='chatglm2-6b', - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - }, + model="chatglm2-6b", + credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], ) assert isinstance(num_tokens, int) - assert num_tokens == 21 \ No newline at end of file + assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/chatglm/test_provider.py b/api/tests/integration_tests/model_runtime/chatglm/test_provider.py index e9c5c4da75..7907805d07 100644 --- a/api/tests/integration_tests/model_runtime/chatglm/test_provider.py +++ b/api/tests/integration_tests/model_runtime/chatglm/test_provider.py @@ -7,19 +7,11 @@ from core.model_runtime.model_providers.chatglm.chatglm import ChatGLMProvider from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_provider_credentials(setup_openai_mock): provider = ChatGLMProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={ - 'api_base': 'hahahaha' - } - ) + provider.validate_provider_credentials(credentials={"api_base": "hahahaha"}) - provider.validate_provider_credentials( - credentials={ - 'api_base': os.environ.get('CHATGLM_API_BASE') - } - ) + provider.validate_provider_credentials(credentials={"api_base": os.environ.get("CHATGLM_API_BASE")}) diff --git a/api/tests/integration_tests/model_runtime/cohere/test_llm.py b/api/tests/integration_tests/model_runtime/cohere/test_llm.py index 5ce4f8ecfe..b7f707e935 100644 --- a/api/tests/integration_tests/model_runtime/cohere/test_llm.py +++ b/api/tests/integration_tests/model_runtime/cohere/test_llm.py @@ -13,87 +13,49 @@ def test_validate_credentials_for_chat_model(): model = CohereLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='command-light-chat', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="command-light-chat", credentials={"api_key": "invalid_key"}) - model.validate_credentials( - model='command-light-chat', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - } - ) + model.validate_credentials(model="command-light-chat", credentials={"api_key": os.environ.get("COHERE_API_KEY")}) def test_validate_credentials_for_completion_model(): model = CohereLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='command-light', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="command-light", credentials={"api_key": "invalid_key"}) - model.validate_credentials( - model='command-light', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - } - ) + model.validate_credentials(model="command-light", credentials={"api_key": os.environ.get("COHERE_API_KEY")}) def test_invoke_completion_model(): model = CohereLargeLanguageModel() - credentials = { - 'api_key': os.environ.get('COHERE_API_KEY') - } + credentials = {"api_key": os.environ.get("COHERE_API_KEY")} result = model.invoke( - model='command-light', + model="command-light", credentials=credentials, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 1 - }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 1}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 - assert model._num_tokens_from_string('command-light', credentials, result.message.content) == 1 + assert model._num_tokens_from_string("command-light", credentials, result.message.content) == 1 def test_invoke_stream_completion_model(): model = CohereLargeLanguageModel() result = model.invoke( - model='command-light', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model="command-light", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -109,28 +71,24 @@ def test_invoke_chat_model(): model = CohereLargeLanguageModel() result = model.invoke( - model='command-light-chat', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, + model="command-light-chat", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.0, - 'p': 0.99, - 'presence_penalty': 0.0, - 'frequency_penalty': 0.0, - 'max_tokens': 10 + "temperature": 0.0, + "p": 0.99, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) @@ -141,24 +99,17 @@ def test_invoke_stream_chat_model(): model = CohereLargeLanguageModel() result = model.invoke( - model='command-light-chat', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, + model="command-light-chat", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -177,32 +128,22 @@ def test_get_num_tokens(): model = CohereLargeLanguageModel() num_tokens = model.get_num_tokens( - model='command-light', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + model="command-light", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert num_tokens == 3 num_tokens = model.get_num_tokens( - model='command-light-chat', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, + model="command-light-chat", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 15 @@ -213,25 +154,17 @@ def test_fine_tuned_model(): # test invoke result = model.invoke( - model='85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY'), - 'mode': 'completion' - }, + model="85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft", + credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "completion"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) @@ -242,25 +175,17 @@ def test_fine_tuned_chat_model(): # test invoke result = model.invoke( - model='94f2d55a-4c79-4c00-bde4-23962e74b170-ft', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY'), - 'mode': 'chat' - }, + model="94f2d55a-4c79-4c00-bde4-23962e74b170-ft", + credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "chat"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) diff --git a/api/tests/integration_tests/model_runtime/cohere/test_provider.py b/api/tests/integration_tests/model_runtime/cohere/test_provider.py index a8f56b6194..fb7e6d3498 100644 --- a/api/tests/integration_tests/model_runtime/cohere/test_provider.py +++ b/api/tests/integration_tests/model_runtime/cohere/test_provider.py @@ -10,12 +10,6 @@ def test_validate_provider_credentials(): provider = CohereProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("COHERE_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/cohere/test_rerank.py b/api/tests/integration_tests/model_runtime/cohere/test_rerank.py index 415c5fbfda..a1b6922128 100644 --- a/api/tests/integration_tests/model_runtime/cohere/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/cohere/test_rerank.py @@ -11,29 +11,17 @@ def test_validate_credentials(): model = CohereRerankModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='rerank-english-v2.0', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": "invalid_key"}) - model.validate_credentials( - model='rerank-english-v2.0', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - } - ) + model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")}) def test_invoke_model(): model = CohereRerankModel() result = model.invoke( - model='rerank-english-v2.0', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, + model="rerank-english-v2.0", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, query="What is the capital of the United States?", docs=[ "Carson City is the capital city of the American state of Nevada. At the 2010 United States " @@ -41,9 +29,9 @@ def test_invoke_model(): "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) " "is the capital of the United States. It is a federal district. The President of the USA and many major " "national government offices are in the territory. This makes it the political center of the United " - "States of America." + "States of America.", ], - score_threshold=0.8 + score_threshold=0.8, ) assert isinstance(result, RerankResult) diff --git a/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py b/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py index 5017ba47e1..ae26d36635 100644 --- a/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py @@ -11,18 +11,10 @@ def test_validate_credentials(): model = CohereTextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='embed-multilingual-v3.0', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="embed-multilingual-v3.0", credentials={"api_key": "invalid_key"}) model.validate_credentials( - model='embed-multilingual-v3.0', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - } + model="embed-multilingual-v3.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")} ) @@ -30,17 +22,10 @@ def test_invoke_model(): model = CohereTextEmbeddingModel() result = model.invoke( - model='embed-multilingual-v3.0', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, - texts=[ - "hello", - "world", - " ".join(["long_text"] * 100), - " ".join(["another_long_text"] * 100) - ], - user="abc-123" + model="embed-multilingual-v3.0", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -52,14 +37,9 @@ def test_get_num_tokens(): model = CohereTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='embed-multilingual-v3.0', - credentials={ - 'api_key': os.environ.get('COHERE_API_KEY') - }, - texts=[ - "hello", - "world" - ] + model="embed-multilingual-v3.0", + credentials={"api_key": os.environ.get("COHERE_API_KEY")}, + texts=["hello", "world"], ) assert num_tokens == 3 diff --git a/api/tests/integration_tests/model_runtime/google/test_llm.py b/api/tests/integration_tests/model_runtime/google/test_llm.py index 00d907d19e..4d9d490a87 100644 --- a/api/tests/integration_tests/model_runtime/google/test_llm.py +++ b/api/tests/integration_tests/model_runtime/google/test_llm.py @@ -16,103 +16,73 @@ from core.model_runtime.model_providers.google.llm.llm import GoogleLargeLanguag from tests.integration_tests.model_runtime.__mock.google import setup_google_mock -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_validate_credentials(setup_google_mock): model = GoogleLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='gemini-pro', - credentials={ - 'google_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="gemini-pro", credentials={"google_api_key": "invalid_key"}) - model.validate_credentials( - model='gemini-pro', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - } - ) + model.validate_credentials(model="gemini-pro", credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}) -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_invoke_model(setup_google_mock): model = GoogleLargeLanguageModel() response = model.invoke( - model='gemini-pro', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, + model="gemini-pro", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', - ), - UserPromptMessage( - content='Give me your worst dad joke or i will unplug you' + content="You are a helpful AI assistant.", ), + UserPromptMessage(content="Give me your worst dad joke or i will unplug you"), AssistantPromptMessage( - content='Why did the scarecrow win an award? Because he was outstanding in his field!' + content="Why did the scarecrow win an award? Because he was outstanding in his field!" ), UserPromptMessage( content=[ - TextPromptMessageContent( - data="ok something snarkier pls" - ), - TextPromptMessageContent( - data="i may still unplug you" - )] - ) + TextPromptMessageContent(data="ok something snarkier pls"), + TextPromptMessageContent(data="i may still unplug you"), + ] + ), ], - model_parameters={ - 'temperature': 0.5, - 'top_p': 1.0, - 'max_tokens_to_sample': 2048 - }, - stop=['How'], + model_parameters={"temperature": 0.5, "top_p": 1.0, "max_tokens_to_sample": 2048}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_invoke_stream_model(setup_google_mock): model = GoogleLargeLanguageModel() response = model.invoke( - model='gemini-pro', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, + model="gemini-pro", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', - ), - UserPromptMessage( - content='Give me your worst dad joke or i will unplug you' + content="You are a helpful AI assistant.", ), + UserPromptMessage(content="Give me your worst dad joke or i will unplug you"), AssistantPromptMessage( - content='Why did the scarecrow win an award? Because he was outstanding in his field!' + content="Why did the scarecrow win an award? Because he was outstanding in his field!" ), UserPromptMessage( content=[ - TextPromptMessageContent( - data="ok something snarkier pls" - ), - TextPromptMessageContent( - data="i may still unplug you" - )] - ) + TextPromptMessageContent(data="ok something snarkier pls"), + TextPromptMessageContent(data="i may still unplug you"), + ] + ), ], - model_parameters={ - 'temperature': 0.2, - 'top_k': 5, - 'max_tokens_to_sample': 2048 - }, + model_parameters={"temperature": 0.2, "top_k": 5, "max_tokens_to_sample": 2048}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -123,88 +93,66 @@ def test_invoke_stream_model(setup_google_mock): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_invoke_chat_model_with_vision(setup_google_mock): model = GoogleLargeLanguageModel() result = model.invoke( - model='gemini-pro-vision', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, + model="gemini-pro-vision", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( - content=[ - TextPromptMessageContent( - data="what do you see?" - ), + content=[ + TextPromptMessageContent(data="what do you see?"), ImagePromptMessageContent( - data='' - ) + data="" + ), ] - ) + ), ], - model_parameters={ - 'temperature': 0.3, - 'top_p': 0.2, - 'top_k': 3, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.3, "top_p": 0.2, "top_k": 3, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock): model = GoogleLargeLanguageModel() result = model.invoke( - model='gemini-pro-vision', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, + model="gemini-pro-vision", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, prompt_messages=[ - SystemPromptMessage( - content='You are a helpful AI assistant.' - ), - UserPromptMessage( - content=[ - TextPromptMessageContent( - data="what do you see?" - ), - ImagePromptMessageContent( - data='' - ) - ] - ), - AssistantPromptMessage( - content="I see a blue letter 'D' with a gradient from light blue to dark blue." - ), + SystemPromptMessage(content="You are a helpful AI assistant."), UserPromptMessage( content=[ - TextPromptMessageContent( - data="what about now?" - ), + TextPromptMessageContent(data="what do you see?"), ImagePromptMessageContent( - data='' - ) + data="" + ), ] - ) + ), + AssistantPromptMessage(content="I see a blue letter 'D' with a gradient from light blue to dark blue."), + UserPromptMessage( + content=[ + TextPromptMessageContent(data="what about now?"), + ImagePromptMessageContent( + data="" + ), + ] + ), ], - model_parameters={ - 'temperature': 0.3, - 'top_p': 0.2, - 'top_k': 3, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.3, "top_p": 0.2, "top_k": 3, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) print(f"resultz: {result.message.content}") @@ -212,23 +160,18 @@ def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock): assert len(result.message.content) > 0 - def test_get_num_tokens(): model = GoogleLargeLanguageModel() num_tokens = model.get_num_tokens( - model='gemini-pro', - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - }, + model="gemini-pro", + credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens > 0 # The exact number of tokens may vary based on the model's tokenization diff --git a/api/tests/integration_tests/model_runtime/google/test_provider.py b/api/tests/integration_tests/model_runtime/google/test_provider.py index 103107ed5a..c217e4fe05 100644 --- a/api/tests/integration_tests/model_runtime/google/test_provider.py +++ b/api/tests/integration_tests/model_runtime/google/test_provider.py @@ -7,17 +7,11 @@ from core.model_runtime.model_providers.google.google import GoogleProvider from tests.integration_tests.model_runtime.__mock.google import setup_google_mock -@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) def test_validate_provider_credentials(setup_google_mock): provider = GoogleProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'google_api_key': os.environ.get('GOOGLE_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py b/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py index 28cd0955b3..6a6cc874fa 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py +++ b/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py @@ -10,87 +10,75 @@ from core.model_runtime.model_providers.huggingface_hub.llm.llm import Huggingfa from tests.integration_tests.model_runtime.__mock.huggingface import setup_huggingface_mock -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_hosted_inference_api_validate_credentials(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='HuggingFaceH4/zephyr-7b-beta', - credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': 'invalid_key' - } + model="HuggingFaceH4/zephyr-7b-beta", + credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"}, ) with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='fake-model', - credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': 'invalid_key' - } + model="fake-model", + credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"}, ) model.validate_credentials( - model='HuggingFaceH4/zephyr-7b-beta', + model="HuggingFaceH4/zephyr-7b-beta", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY') - } + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + }, ) -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_hosted_inference_api_invoke_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='HuggingFaceH4/zephyr-7b-beta', + model="HuggingFaceH4/zephyr-7b-beta", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY') + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='HuggingFaceH4/zephyr-7b-beta', + model="HuggingFaceH4/zephyr-7b-beta", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY') + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -101,86 +89,81 @@ def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text_generation_validate_credentials(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='openchat/openchat_3.5', + model="openchat/openchat_3.5", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': 'invalid_key', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text-generation' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": "invalid_key", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"), + "task_type": "text-generation", + }, ) model.validate_credentials( - model='openchat/openchat_3.5', + model="openchat/openchat_3.5", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text-generation' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"), + "task_type": "text-generation", + }, ) -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text_generation_invoke_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='openchat/openchat_3.5', + model="openchat/openchat_3.5", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text-generation' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"), + "task_type": "text-generation", }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='openchat/openchat_3.5', + model="openchat/openchat_3.5", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text-generation' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"), + "task_type": "text-generation", }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -191,86 +174,81 @@ def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingfa assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text2text_generation_validate_credentials(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='google/mt5-base', + model="google/mt5-base", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': 'invalid_key', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text2text-generation' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": "invalid_key", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", + }, ) model.validate_credentials( - model='google/mt5-base', + model="google/mt5-base", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text2text-generation' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", + }, ) -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text2text_generation_invoke_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='google/mt5-base', + model="google/mt5-base", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text2text-generation' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 -@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True) def test_inference_endpoints_text2text_generation_invoke_stream_model(setup_huggingface_mock): model = HuggingfaceHubLargeLanguageModel() response = model.invoke( - model='google/mt5-base', + model="google/mt5-base", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text2text-generation' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], + prompt_messages=[UserPromptMessage(content="Who are you?")], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -286,18 +264,14 @@ def test_get_num_tokens(): model = HuggingfaceHubLargeLanguageModel() num_tokens = model.get_num_tokens( - model='google/mt5-base', + model="google/mt5-base", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'), - 'task_type': 'text2text-generation' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"), + "task_type": "text2text-generation", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert num_tokens == 7 diff --git a/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py b/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py index d03b3186cb..0ee593f38a 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py @@ -14,19 +14,19 @@ def test_hosted_inference_api_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='facebook/bart-base', + model="facebook/bart-base", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': 'invalid_key', - } + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": "invalid_key", + }, ) model.validate_credentials( - model='facebook/bart-base', + model="facebook/bart-base", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - } + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + }, ) @@ -34,15 +34,12 @@ def test_hosted_inference_api_invoke_model(): model = HuggingfaceHubTextEmbeddingModel() result = model.invoke( - model='facebook/bart-base', + model="facebook/bart-base", credentials={ - 'huggingfacehub_api_type': 'hosted_inference_api', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), + "huggingfacehub_api_type": "hosted_inference_api", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert isinstance(result, TextEmbeddingResult) @@ -55,25 +52,25 @@ def test_inference_endpoints_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='all-MiniLM-L6-v2', + model="all-MiniLM-L6-v2", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': 'invalid_key', - 'huggingface_namespace': 'Dify-AI', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'), - 'task_type': 'feature-extraction' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": "invalid_key", + "huggingface_namespace": "Dify-AI", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"), + "task_type": "feature-extraction", + }, ) model.validate_credentials( - model='all-MiniLM-L6-v2', + model="all-MiniLM-L6-v2", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingface_namespace': 'Dify-AI', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'), - 'task_type': 'feature-extraction' - } + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingface_namespace": "Dify-AI", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"), + "task_type": "feature-extraction", + }, ) @@ -81,18 +78,15 @@ def test_inference_endpoints_invoke_model(): model = HuggingfaceHubTextEmbeddingModel() result = model.invoke( - model='all-MiniLM-L6-v2', + model="all-MiniLM-L6-v2", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingface_namespace': 'Dify-AI', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'), - 'task_type': 'feature-extraction' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingface_namespace": "Dify-AI", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"), + "task_type": "feature-extraction", }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert isinstance(result, TextEmbeddingResult) @@ -104,18 +98,15 @@ def test_get_num_tokens(): model = HuggingfaceHubTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='all-MiniLM-L6-v2', + model="all-MiniLM-L6-v2", credentials={ - 'huggingfacehub_api_type': 'inference_endpoints', - 'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'), - 'huggingface_namespace': 'Dify-AI', - 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'), - 'task_type': 'feature-extraction' + "huggingfacehub_api_type": "inference_endpoints", + "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"), + "huggingface_namespace": "Dify-AI", + "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"), + "task_type": "feature-extraction", }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py b/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py index ed371fbc07..b1fa9d5ca5 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py +++ b/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py @@ -10,61 +10,59 @@ from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embe ) from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass -MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" @pytest.fixture def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch): if MOCK: - monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter) - monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize) - monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings) - monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank) + monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter) + monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize) + monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings) + monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank) yield if MOCK: monkeypatch.undo() -@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True) def test_validate_credentials(setup_tei_mock): model = HuggingfaceTeiTextEmbeddingModel() # model name is only used in mock - model_name = 'embedding' + model_name = "embedding" if MOCK: # TEI Provider will check model type by API endpoint, at real server, the model type is correct. # So we dont need to check model type here. Only check in mock with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='reranker', + model="reranker", credentials={ - 'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""), - } + "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), + }, ) model.validate_credentials( model=model_name, credentials={ - 'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""), - } + "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), + }, ) -@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True) def test_invoke_model(setup_tei_mock): model = HuggingfaceTeiTextEmbeddingModel() - model_name = 'embedding' + model_name = "embedding" result = model.invoke( model=model_name, credentials={ - 'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""), + "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) diff --git a/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py b/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py index 57e229e6be..45370d9fba 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py @@ -11,63 +11,65 @@ from core.model_runtime.model_providers.huggingface_tei.rerank.rerank import ( from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass -MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" @pytest.fixture def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch): if MOCK: - monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter) - monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize) - monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings) - monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank) + monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter) + monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize) + monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings) + monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank) yield if MOCK: monkeypatch.undo() -@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True) def test_validate_credentials(setup_tei_mock): model = HuggingfaceTeiRerankModel() # model name is only used in mock - model_name = 'reranker' + model_name = "reranker" if MOCK: # TEI Provider will check model type by API endpoint, at real server, the model type is correct. # So we dont need to check model type here. Only check in mock with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='embedding', + model="embedding", credentials={ - 'server_url': os.environ.get('TEI_RERANK_SERVER_URL'), - } + "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), + }, ) model.validate_credentials( model=model_name, credentials={ - 'server_url': os.environ.get('TEI_RERANK_SERVER_URL'), - } + "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), + }, ) -@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True) def test_invoke_model(setup_tei_mock): model = HuggingfaceTeiRerankModel() # model name is only used in mock - model_name = 'reranker' + model_name = "reranker" result = model.invoke( model=model_name, credentials={ - 'server_url': os.environ.get('TEI_RERANK_SERVER_URL'), + "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), }, query="Who is Kasumi?", docs=[ - "Kasumi is a girl's name of Japanese origin meaning \"mist\".", + 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", - "and she leads a team named PopiParty." + "and she leads a team named PopiParty.", ], - score_threshold=0.8 + score_threshold=0.8, ) assert isinstance(result, RerankResult) diff --git a/api/tests/integration_tests/model_runtime/hunyuan/test_llm.py b/api/tests/integration_tests/model_runtime/hunyuan/test_llm.py index 305f967ef0..b3049a06d9 100644 --- a/api/tests/integration_tests/model_runtime/hunyuan/test_llm.py +++ b/api/tests/integration_tests/model_runtime/hunyuan/test_llm.py @@ -14,19 +14,15 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='hunyuan-standard', - credentials={ - 'secret_id': 'invalid_key', - 'secret_key': 'invalid_key' - } + model="hunyuan-standard", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"} ) model.validate_credentials( - model='hunyuan-standard', + model="hunyuan-standard", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') - } + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), + }, ) @@ -34,23 +30,16 @@ def test_invoke_model(): model = HunyuanLargeLanguageModel() response = model.invoke( - model='hunyuan-standard', + model="hunyuan-standard", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Hi' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 10 - }, - stop=['How'], + prompt_messages=[UserPromptMessage(content="Hi")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -61,23 +50,15 @@ def test_invoke_stream_model(): model = HunyuanLargeLanguageModel() response = model.invoke( - model='hunyuan-standard', + model="hunyuan-standard", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hi' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 100, - 'seed': 1234 + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, + prompt_messages=[UserPromptMessage(content="Hi")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -93,19 +74,17 @@ def test_get_num_tokens(): model = HunyuanLargeLanguageModel() num_tokens = model.get_num_tokens( - model='hunyuan-standard', + model="hunyuan-standard", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 14 diff --git a/api/tests/integration_tests/model_runtime/hunyuan/test_provider.py b/api/tests/integration_tests/model_runtime/hunyuan/test_provider.py index bdec3d0e22..e3748c2ce7 100644 --- a/api/tests/integration_tests/model_runtime/hunyuan/test_provider.py +++ b/api/tests/integration_tests/model_runtime/hunyuan/test_provider.py @@ -10,16 +10,11 @@ def test_validate_provider_credentials(): provider = HunyuanProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={ - 'secret_id': 'invalid_key', - 'secret_key': 'invalid_key' - } - ) + provider.validate_provider_credentials(credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"}) provider.validate_provider_credentials( credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), } ) diff --git a/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py b/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py index 7ae6c0e456..69d14dffee 100644 --- a/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py @@ -12,19 +12,15 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='hunyuan-embedding', - credentials={ - 'secret_id': 'invalid_key', - 'secret_key': 'invalid_key' - } + model="hunyuan-embedding", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"} ) model.validate_credentials( - model='hunyuan-embedding', + model="hunyuan-embedding", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') - } + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), + }, ) @@ -32,47 +28,43 @@ def test_invoke_model(): model = HunyuanTextEmbeddingModel() result = model.invoke( - model='hunyuan-embedding', + model="hunyuan-embedding", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 assert result.usage.total_tokens == 6 + def test_get_num_tokens(): model = HunyuanTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='hunyuan-embedding', + model="hunyuan-embedding", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 + def test_max_chunks(): model = HunyuanTextEmbeddingModel() result = model.invoke( - model='hunyuan-embedding', + model="hunyuan-embedding", credentials={ - 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), - 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + "secret_id": os.environ.get("HUNYUAN_SECRET_ID"), + "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"), }, texts=[ "hello", @@ -97,8 +89,8 @@ def test_max_chunks(): "world", "hello", "world", - ] + ], ) assert isinstance(result, TextEmbeddingResult) - assert len(result.embeddings) == 22 \ No newline at end of file + assert len(result.embeddings) == 22 diff --git a/api/tests/integration_tests/model_runtime/jina/test_provider.py b/api/tests/integration_tests/model_runtime/jina/test_provider.py index 2b43248388..e3b6128c59 100644 --- a/api/tests/integration_tests/model_runtime/jina/test_provider.py +++ b/api/tests/integration_tests/model_runtime/jina/test_provider.py @@ -10,14 +10,6 @@ def test_validate_provider_credentials(): provider = JinaProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={ - 'api_key': 'hahahaha' - } - ) + provider.validate_provider_credentials(credentials={"api_key": "hahahaha"}) - provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('JINA_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("JINA_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py b/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py index ac17566174..290735ec49 100644 --- a/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py @@ -11,18 +11,10 @@ def test_validate_credentials(): model = JinaTextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='jina-embeddings-v2-base-en', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="jina-embeddings-v2-base-en", credentials={"api_key": "invalid_key"}) model.validate_credentials( - model='jina-embeddings-v2-base-en', - credentials={ - 'api_key': os.environ.get('JINA_API_KEY') - } + model="jina-embeddings-v2-base-en", credentials={"api_key": os.environ.get("JINA_API_KEY")} ) @@ -30,15 +22,12 @@ def test_invoke_model(): model = JinaTextEmbeddingModel() result = model.invoke( - model='jina-embeddings-v2-base-en', + model="jina-embeddings-v2-base-en", credentials={ - 'api_key': os.environ.get('JINA_API_KEY'), + "api_key": os.environ.get("JINA_API_KEY"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -50,14 +39,11 @@ def test_get_num_tokens(): model = JinaTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='jina-embeddings-v2-base-en', + model="jina-embeddings-v2-base-en", credentials={ - 'api_key': os.environ.get('JINA_API_KEY'), + "api_key": os.environ.get("JINA_API_KEY"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 6 diff --git a/api/tests/integration_tests/model_runtime/localai/test_embedding.py b/api/tests/integration_tests/model_runtime/localai/test_embedding.py index e05345ee56..7fd9f2b300 100644 --- a/api/tests/integration_tests/model_runtime/localai/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/localai/test_embedding.py @@ -1,4 +1,4 @@ """ - LocalAI Embedding Interface is temporarily unavailable due to - we could not find a way to test it for now. -""" \ No newline at end of file +LocalAI Embedding Interface is temporarily unavailable due to +we could not find a way to test it for now. +""" diff --git a/api/tests/integration_tests/model_runtime/localai/test_llm.py b/api/tests/integration_tests/model_runtime/localai/test_llm.py index 6f421403d4..aa5436c34f 100644 --- a/api/tests/integration_tests/model_runtime/localai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/localai/test_llm.py @@ -21,99 +21,78 @@ def test_validate_credentials_for_chat_model(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': 'hahahaha', - 'completion_type': 'completion', - } + "server_url": "hahahaha", + "completion_type": "completion", + }, ) model.validate_credentials( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'completion', - } + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "completion", + }, ) + def test_invoke_completion_model(): model = LocalAILanguageModel() response = model.invoke( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'completion', - }, - prompt_messages=[ - UserPromptMessage( - content='ping' - ) - ], - model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'max_tokens': 10 + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "completion", }, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, stop=[], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_chat_model(): model = LocalAILanguageModel() response = model.invoke( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'chat_completion', - }, - prompt_messages=[ - UserPromptMessage( - content='ping' - ) - ], - model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'max_tokens': 10 + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "chat_completion", }, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, stop=[], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_stream_completion_model(): model = LocalAILanguageModel() response = model.invoke( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'completion', + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "completion", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'max_tokens': 10 - }, - stop=['you'], + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -123,28 +102,21 @@ def test_invoke_stream_completion_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_invoke_stream_chat_model(): model = LocalAILanguageModel() response = model.invoke( - model='chinese-llama-2-7b', + model="chinese-llama-2-7b", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'chat_completion', + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "chat_completion", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'max_tokens': 10 - }, - stop=['you'], + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -154,64 +126,48 @@ def test_invoke_stream_chat_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_get_num_tokens(): model = LocalAILanguageModel() num_tokens = model.get_num_tokens( - model='????', + model="????", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'chat_completion', + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "chat_completion", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) - ] + ], ) assert isinstance(num_tokens, int) assert num_tokens == 77 num_tokens = model.get_num_tokens( - model='????', + model="????", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'chat_completion', + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "chat_completion", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/localai/test_rerank.py b/api/tests/integration_tests/model_runtime/localai/test_rerank.py index 99847bc852..13c7df6d14 100644 --- a/api/tests/integration_tests/model_runtime/localai/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/localai/test_rerank.py @@ -12,30 +12,29 @@ def test_validate_credentials_for_chat_model(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='bge-reranker-v2-m3', + model="bge-reranker-v2-m3", credentials={ - 'server_url': 'hahahaha', - 'completion_type': 'completion', - } + "server_url": "hahahaha", + "completion_type": "completion", + }, ) model.validate_credentials( - model='bge-reranker-base', + model="bge-reranker-base", credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL'), - 'completion_type': 'completion', - } + "server_url": os.environ.get("LOCALAI_SERVER_URL"), + "completion_type": "completion", + }, ) + def test_invoke_rerank_model(): model = LocalaiRerankModel() response = model.invoke( - model='bge-reranker-base', - credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL') - }, - query='Organic skincare products for sensitive skin', + model="bge-reranker-base", + credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")}, + query="Organic skincare products for sensitive skin", docs=[ "Eco-friendly kitchenware for modern homes", "Biodegradable cleaning supplies for eco-conscious consumers", @@ -45,43 +44,38 @@ def test_invoke_rerank_model(): "Sustainable gardening tools and compost solutions", "Sensitive skin-friendly facial cleansers and toners", "Organic food wraps and storage solutions", - "Yoga mats made from recycled materials" + "Yoga mats made from recycled materials", ], top_n=3, score_threshold=0.75, - user="abc-123" + user="abc-123", ) assert isinstance(response, RerankResult) assert len(response.docs) == 3 + def test__invoke(): model = LocalaiRerankModel() # Test case 1: Empty docs result = model._invoke( - model='bge-reranker-base', - credentials={ - 'server_url': 'https://example.com', - 'api_key': '1234567890' - }, - query='Organic skincare products for sensitive skin', + model="bge-reranker-base", + credentials={"server_url": "https://example.com", "api_key": "1234567890"}, + query="Organic skincare products for sensitive skin", docs=[], top_n=3, score_threshold=0.75, - user="abc-123" + user="abc-123", ) assert isinstance(result, RerankResult) assert len(result.docs) == 0 # Test case 2: Valid invocation result = model._invoke( - model='bge-reranker-base', - credentials={ - 'server_url': 'https://example.com', - 'api_key': '1234567890' - }, - query='Organic skincare products for sensitive skin', + model="bge-reranker-base", + credentials={"server_url": "https://example.com", "api_key": "1234567890"}, + query="Organic skincare products for sensitive skin", docs=[ "Eco-friendly kitchenware for modern homes", "Biodegradable cleaning supplies for eco-conscious consumers", @@ -91,12 +85,12 @@ def test__invoke(): "Sustainable gardening tools and compost solutions", "Sensitive skin-friendly facial cleansers and toners", "Organic food wraps and storage solutions", - "Yoga mats made from recycled materials" + "Yoga mats made from recycled materials", ], top_n=3, score_threshold=0.75, - user="abc-123" + user="abc-123", ) assert isinstance(result, RerankResult) assert len(result.docs) == 3 - assert all(isinstance(doc, RerankDocument) for doc in result.docs) \ No newline at end of file + assert all(isinstance(doc, RerankDocument) for doc in result.docs) diff --git a/api/tests/integration_tests/model_runtime/localai/test_speech2text.py b/api/tests/integration_tests/model_runtime/localai/test_speech2text.py index 3fd2ebed4f..91b7a5752c 100644 --- a/api/tests/integration_tests/model_runtime/localai/test_speech2text.py +++ b/api/tests/integration_tests/model_runtime/localai/test_speech2text.py @@ -10,19 +10,9 @@ def test_validate_credentials(): model = LocalAISpeech2text() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='whisper-1', - credentials={ - 'server_url': 'invalid_url' - } - ) + model.validate_credentials(model="whisper-1", credentials={"server_url": "invalid_url"}) - model.validate_credentials( - model='whisper-1', - credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL') - } - ) + model.validate_credentials(model="whisper-1", credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")}) def test_invoke_model(): @@ -32,23 +22,21 @@ def test_invoke_model(): current_dir = os.path.dirname(os.path.abspath(__file__)) # Get assets directory - assets_dir = os.path.join(os.path.dirname(current_dir), 'assets') + assets_dir = os.path.join(os.path.dirname(current_dir), "assets") # Construct the path to the audio file - audio_file_path = os.path.join(assets_dir, 'audio.mp3') + audio_file_path = os.path.join(assets_dir, "audio.mp3") # Open the file and get the file object - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: file = audio_file result = model.invoke( - model='whisper-1', - credentials={ - 'server_url': os.environ.get('LOCALAI_SERVER_URL') - }, + model="whisper-1", + credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")}, file=file, - user="abc-123" + user="abc-123", ) assert isinstance(result, str) - assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10' \ No newline at end of file + assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" diff --git a/api/tests/integration_tests/model_runtime/minimax/test_embedding.py b/api/tests/integration_tests/model_runtime/minimax/test_embedding.py index 6f4b8a163f..cf2a28eb9e 100644 --- a/api/tests/integration_tests/model_runtime/minimax/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/minimax/test_embedding.py @@ -12,54 +12,47 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='embo-01', - credentials={ - 'minimax_api_key': 'invalid_key', - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') - } + model="embo-01", + credentials={"minimax_api_key": "invalid_key", "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID")}, ) model.validate_credentials( - model='embo-01', + model="embo-01", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') - } + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), + }, ) + def test_invoke_model(): model = MinimaxTextEmbeddingModel() result = model.invoke( - model='embo-01', + model="embo-01", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 assert result.usage.total_tokens == 16 + def test_get_num_tokens(): model = MinimaxTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='embo-01', + model="embo-01", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/minimax/test_llm.py b/api/tests/integration_tests/model_runtime/minimax/test_llm.py index 570e4901a9..aacde04d32 100644 --- a/api/tests/integration_tests/model_runtime/minimax/test_llm.py +++ b/api/tests/integration_tests/model_runtime/minimax/test_llm.py @@ -17,79 +17,70 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) + def test_validate_credentials_for_chat_model(): sleep(3) model = MinimaxLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='abab5.5-chat', - credentials={ - 'minimax_api_key': 'invalid_key', - 'minimax_group_id': 'invalid_key' - } + model="abab5.5-chat", credentials={"minimax_api_key": "invalid_key", "minimax_group_id": "invalid_key"} ) model.validate_credentials( - model='abab5.5-chat', + model="abab5.5-chat", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') - } + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), + }, ) + def test_invoke_model(): sleep(3) model = MinimaxLargeLanguageModel() response = model.invoke( - model='abab5-chat', + model="abab5-chat", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_stream_model(): sleep(3) model = MinimaxLargeLanguageModel() response = model.invoke( - model='abab5.5-chat', + model="abab5.5-chat", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -99,34 +90,31 @@ def test_invoke_stream_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_invoke_with_search(): sleep(3) model = MinimaxLargeLanguageModel() response = model.invoke( - model='abab5.5-chat', + model="abab5.5-chat", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - prompt_messages=[ - UserPromptMessage( - content='北京今天的天气怎么样' - ) - ], + prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, - 'plugin_web_search': True, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, + "plugin_web_search": True, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) - total_message = '' + total_message = "" for chunk in response: assert isinstance(chunk, LLMResultChunk) assert isinstance(chunk.delta, LLMResultChunkDelta) @@ -134,25 +122,22 @@ def test_invoke_with_search(): total_message += chunk.delta.message.content assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True - assert '参考资料' in total_message + assert "参考资料" in total_message + def test_get_num_tokens(): sleep(3) model = MinimaxLargeLanguageModel() response = model.get_num_tokens( - model='abab5.5-chat', + model="abab5.5-chat", credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - tools=[] + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], ) assert isinstance(response, int) - assert response == 30 \ No newline at end of file + assert response == 30 diff --git a/api/tests/integration_tests/model_runtime/minimax/test_provider.py b/api/tests/integration_tests/model_runtime/minimax/test_provider.py index 4c5462c6df..575ed13eef 100644 --- a/api/tests/integration_tests/model_runtime/minimax/test_provider.py +++ b/api/tests/integration_tests/model_runtime/minimax/test_provider.py @@ -12,14 +12,14 @@ def test_validate_provider_credentials(): with pytest.raises(CredentialsValidateFailedError): provider.validate_provider_credentials( credentials={ - 'minimax_api_key': 'hahahaha', - 'minimax_group_id': '123', + "minimax_api_key": "hahahaha", + "minimax_group_id": "123", } ) provider.validate_provider_credentials( credentials={ - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID'), + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), } ) diff --git a/api/tests/integration_tests/model_runtime/novita/test_llm.py b/api/tests/integration_tests/model_runtime/novita/test_llm.py index 4ebc68493f..35fa0dc190 100644 --- a/api/tests/integration_tests/model_runtime/novita/test_llm.py +++ b/api/tests/integration_tests/model_runtime/novita/test_llm.py @@ -19,19 +19,12 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='meta-llama/llama-3-8b-instruct', - credentials={ - 'api_key': 'invalid_key', - 'mode': 'chat' - } + model="meta-llama/llama-3-8b-instruct", credentials={"api_key": "invalid_key", "mode": "chat"} ) model.validate_credentials( - model='meta-llama/llama-3-8b-instruct', - credentials={ - 'api_key': os.environ.get('NOVITA_API_KEY'), - 'mode': 'chat' - } + model="meta-llama/llama-3-8b-instruct", + credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "chat"}, ) @@ -39,27 +32,22 @@ def test_invoke_model(): model = NovitaLargeLanguageModel() response = model.invoke( - model='meta-llama/llama-3-8b-instruct', - credentials={ - 'api_key': os.environ.get('NOVITA_API_KEY'), - 'mode': 'completion' - }, + model="meta-llama/llama-3-8b-instruct", + credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "completion"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_p': 0.5, - 'max_tokens': 10, + "temperature": 1.0, + "top_p": 0.5, + "max_tokens": 10, }, - stop=['How'], + stop=["How"], stream=False, - user="novita" + user="novita", ) assert isinstance(response, LLMResult) @@ -70,27 +58,17 @@ def test_invoke_stream_model(): model = NovitaLargeLanguageModel() response = model.invoke( - model='meta-llama/llama-3-8b-instruct', - credentials={ - 'api_key': os.environ.get('NOVITA_API_KEY'), - 'mode': 'chat' - }, + model="meta-llama/llama-3-8b-instruct", + credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "chat"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], - model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, - 'max_tokens': 100 - }, + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "max_tokens": 100}, stream=True, - user="novita" + user="novita", ) assert isinstance(response, Generator) @@ -105,18 +83,16 @@ def test_get_num_tokens(): model = NovitaLargeLanguageModel() num_tokens = model.get_num_tokens( - model='meta-llama/llama-3-8b-instruct', + model="meta-llama/llama-3-8b-instruct", credentials={ - 'api_key': os.environ.get('NOVITA_API_KEY'), + "api_key": os.environ.get("NOVITA_API_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/novita/test_provider.py b/api/tests/integration_tests/model_runtime/novita/test_provider.py index bb3f19dc85..191af99db2 100644 --- a/api/tests/integration_tests/model_runtime/novita/test_provider.py +++ b/api/tests/integration_tests/model_runtime/novita/test_provider.py @@ -10,12 +10,10 @@ def test_validate_provider_credentials(): provider = NovitaProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) provider.validate_provider_credentials( credentials={ - 'api_key': os.environ.get('NOVITA_API_KEY'), + "api_key": os.environ.get("NOVITA_API_KEY"), } ) diff --git a/api/tests/integration_tests/model_runtime/ollama/test_llm.py b/api/tests/integration_tests/model_runtime/ollama/test_llm.py index 272e639a8a..58a1339f50 100644 --- a/api/tests/integration_tests/model_runtime/ollama/test_llm.py +++ b/api/tests/integration_tests/model_runtime/ollama/test_llm.py @@ -20,23 +20,23 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': 'http://localhost:21434', - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, - } + "base_url": "http://localhost:21434", + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, + }, ) model.validate_credentials( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, - } + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, + }, ) @@ -44,26 +44,17 @@ def test_invoke_model(): model = OllamaLargeLanguageModel() response = model.invoke( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, - 'num_predict': 10 - }, - stop=['How'], - stream=False + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "num_predict": 10}, + stop=["How"], + stream=False, ) assert isinstance(response, LLMResult) @@ -74,29 +65,22 @@ def test_invoke_stream_model(): model = OllamaLargeLanguageModel() response = model.invoke( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], - model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, - 'num_predict': 10 - }, - stop=['How'], - stream=True + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "num_predict": 10}, + stop=["How"], + stream=True, ) assert isinstance(response, Generator) @@ -111,26 +95,17 @@ def test_invoke_completion_model(): model = OllamaLargeLanguageModel() response = model.invoke( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'completion', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "completion", + "context_size": 2048, + "max_tokens": 2048, }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, - 'num_predict': 10 - }, - stop=['How'], - stream=False + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "num_predict": 10}, + stop=["How"], + stream=False, ) assert isinstance(response, LLMResult) @@ -141,29 +116,22 @@ def test_invoke_stream_completion_model(): model = OllamaLargeLanguageModel() response = model.invoke( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'completion', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "completion", + "context_size": 2048, + "max_tokens": 2048, }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], - model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, - 'num_predict': 10 - }, - stop=['How'], - stream=True + model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "num_predict": 10}, + stop=["How"], + stream=True, ) assert isinstance(response, Generator) @@ -178,29 +146,26 @@ def test_invoke_completion_model_with_vision(): model = OllamaLargeLanguageModel() result = model.invoke( - model='llava', + model="llava", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'completion', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "completion", + "context_size": 2048, + "max_tokens": 2048, }, prompt_messages=[ UserPromptMessage( content=[ TextPromptMessageContent( - data='What is this in this picture?', + data="What is this in this picture?", ), ImagePromptMessageContent( - data='' - ) + data="" + ), ] ) ], - model_parameters={ - 'temperature': 0.1, - 'num_predict': 100 - }, + model_parameters={"temperature": 0.1, "num_predict": 100}, stream=False, ) @@ -212,29 +177,26 @@ def test_invoke_chat_model_with_vision(): model = OllamaLargeLanguageModel() result = model.invoke( - model='llava', + model="llava", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, }, prompt_messages=[ UserPromptMessage( content=[ TextPromptMessageContent( - data='What is this in this picture?', + data="What is this in this picture?", ), ImagePromptMessageContent( - data='' - ) + data="" + ), ] ) ], - model_parameters={ - 'temperature': 0.1, - 'num_predict': 100 - }, + model_parameters={"temperature": 0.1, "num_predict": 100}, stream=False, ) @@ -246,18 +208,14 @@ def test_get_num_tokens(): model = OllamaLargeLanguageModel() num_tokens = model.get_num_tokens( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 2048, - 'max_tokens': 2048, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 2048, + "max_tokens": 2048, }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py b/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py index c5f5918235..3c4f740a4f 100644 --- a/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py @@ -12,21 +12,21 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': 'http://localhost:21434', - 'mode': 'chat', - 'context_size': 4096, - } + "base_url": "http://localhost:21434", + "mode": "chat", + "context_size": 4096, + }, ) model.validate_credentials( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 4096, - } + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 4096, + }, ) @@ -34,17 +34,14 @@ def test_invoke_model(): model = OllamaEmbeddingModel() result = model.invoke( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 4096, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 4096, }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -56,16 +53,13 @@ def test_get_num_tokens(): model = OllamaEmbeddingModel() num_tokens = model.get_num_tokens( - model='mistral:text', + model="mistral:text", credentials={ - 'base_url': os.environ.get('OLLAMA_BASE_URL'), - 'mode': 'chat', - 'context_size': 4096, + "base_url": os.environ.get("OLLAMA_BASE_URL"), + "mode": "chat", + "context_size": 4096, }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/openai/test_llm.py b/api/tests/integration_tests/model_runtime/openai/test_llm.py index bf4ac53579..3b3ea9ec80 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openai/test_llm.py @@ -28,92 +28,61 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_credentials_for_chat_model(setup_openai_mock): model = OpenAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="gpt-3.5-turbo", credentials={"openai_api_key": "invalid_key"}) - model.validate_credentials( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - ) + model.validate_credentials(model="gpt-3.5-turbo", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_validate_credentials_for_completion_model(setup_openai_mock): model = OpenAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='text-davinci-003', - credentials={ - 'openai_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="text-davinci-003", credentials={"openai_api_key": "invalid_key"}) model.validate_credentials( - model='text-davinci-003', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + model="text-davinci-003", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")} ) -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_invoke_completion_model(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo-instruct', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY'), - 'openai_api_base': 'https://api.openai.com' - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 1 - }, + model="gpt-3.5-turbo-instruct", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY"), "openai_api_base": "https://api.openai.com"}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 1}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 - assert model._num_tokens_from_string('gpt-3.5-turbo-instruct', result.message.content) == 1 + assert model._num_tokens_from_string("gpt-3.5-turbo-instruct", result.message.content) == 1 -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_invoke_stream_completion_model(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo-instruct', + model="gpt-3.5-turbo-instruct", credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY'), - 'openai_organization': os.environ.get('OPENAI_ORGANIZATION'), - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 + "openai_api_key": os.environ.get("OPENAI_API_KEY"), + "openai_organization": os.environ.get("OPENAI_ORGANIZATION"), }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -124,166 +93,131 @@ def test_invoke_stream_completion_model(setup_openai_mock): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="gpt-3.5-turbo", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.0, - 'top_p': 1.0, - 'presence_penalty': 0.0, - 'frequency_penalty': 0.0, - 'max_tokens': 10 + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model_with_vision(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-4-vision-preview', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="gpt-4-vision-preview", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content=[ TextPromptMessageContent( - data='Hello World!', + data="Hello World!", ), ImagePromptMessageContent( - data='' - ) + data="" + ), ] - ) + ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model_with_tools(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="gpt-3.5-turbo", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content="what's the weather today in London?", - ) + ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), PromptMessageTool( - name='get_stock_price', - description='Get the current stock price', + name="get_stock_price", + description="Get the current stock price", parameters={ "type": "object", - "properties": { - "symbol": { - "type": "string", - "description": "The stock symbol" - } - }, - "required": [ - "symbol" - ] - } - ) + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), ], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert isinstance(result.message, AssistantPromptMessage) assert len(result.message.tool_calls) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_stream_chat_model(setup_openai_mock): model = OpenAILargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="gpt-3.5-turbo", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -302,68 +236,46 @@ def test_get_num_tokens(): model = OpenAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='gpt-3.5-turbo-instruct', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + model="gpt-3.5-turbo-instruct", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert num_tokens == 3 num_tokens = model.get_num_tokens( - model='gpt-3.5-turbo', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="gpt-3.5-turbo", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), - ] + ], ) assert num_tokens == 72 -@pytest.mark.parametrize('setup_openai_mock', [['chat', 'remote']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat", "remote"]], indirect=True) def test_fine_tuned_models(setup_openai_mock): model = OpenAILargeLanguageModel() - remote_models = model.remote_models(credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }) + remote_models = model.remote_models(credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) if not remote_models: assert isinstance(remote_models, list) @@ -379,29 +291,23 @@ def test_fine_tuned_models(setup_openai_mock): # test invoke result = model.invoke( model=llm_model.model, - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) + def test__get_num_tokens_by_gpt2(): model = OpenAILargeLanguageModel() - num_tokens = model._get_num_tokens_by_gpt2('Hello World!') + num_tokens = model._get_num_tokens_by_gpt2("Hello World!") assert num_tokens == 3 diff --git a/api/tests/integration_tests/model_runtime/openai/test_moderation.py b/api/tests/integration_tests/model_runtime/openai/test_moderation.py index 04f9b9f33b..6de2624717 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_moderation.py +++ b/api/tests/integration_tests/model_runtime/openai/test_moderation.py @@ -7,48 +7,37 @@ from core.model_runtime.model_providers.openai.moderation.moderation import Open from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['moderation']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["moderation"]], indirect=True) def test_validate_credentials(setup_openai_mock): model = OpenAIModerationModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='text-moderation-stable', - credentials={ - 'openai_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="text-moderation-stable", credentials={"openai_api_key": "invalid_key"}) model.validate_credentials( - model='text-moderation-stable', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + model="text-moderation-stable", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")} ) -@pytest.mark.parametrize('setup_openai_mock', [['moderation']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["moderation"]], indirect=True) def test_invoke_model(setup_openai_mock): model = OpenAIModerationModel() result = model.invoke( - model='text-moderation-stable', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="text-moderation-stable", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, text="hello", - user="abc-123" + user="abc-123", ) assert isinstance(result, bool) assert result is False result = model.invoke( - model='text-moderation-stable', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="text-moderation-stable", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, text="i will kill you", - user="abc-123" + user="abc-123", ) assert isinstance(result, bool) diff --git a/api/tests/integration_tests/model_runtime/openai/test_provider.py b/api/tests/integration_tests/model_runtime/openai/test_provider.py index 5314bffbdf..4d56cfcf3c 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_provider.py +++ b/api/tests/integration_tests/model_runtime/openai/test_provider.py @@ -7,17 +7,11 @@ from core.model_runtime.model_providers.openai.openai import OpenAIProvider from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_provider_credentials(setup_openai_mock): provider = OpenAIProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/openai/test_speech2text.py b/api/tests/integration_tests/model_runtime/openai/test_speech2text.py index f1a5c4fd23..aa92c8b61f 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_speech2text.py +++ b/api/tests/integration_tests/model_runtime/openai/test_speech2text.py @@ -7,26 +7,17 @@ from core.model_runtime.model_providers.openai.speech2text.speech2text import Op from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['speech2text']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["speech2text"]], indirect=True) def test_validate_credentials(setup_openai_mock): model = OpenAISpeech2TextModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='whisper-1', - credentials={ - 'openai_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="whisper-1", credentials={"openai_api_key": "invalid_key"}) - model.validate_credentials( - model='whisper-1', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - ) + model.validate_credentials(model="whisper-1", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) -@pytest.mark.parametrize('setup_openai_mock', [['speech2text']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["speech2text"]], indirect=True) def test_invoke_model(setup_openai_mock): model = OpenAISpeech2TextModel() @@ -34,23 +25,21 @@ def test_invoke_model(setup_openai_mock): current_dir = os.path.dirname(os.path.abspath(__file__)) # Get assets directory - assets_dir = os.path.join(os.path.dirname(current_dir), 'assets') + assets_dir = os.path.join(os.path.dirname(current_dir), "assets") # Construct the path to the audio file - audio_file_path = os.path.join(assets_dir, 'audio.mp3') + audio_file_path = os.path.join(assets_dir, "audio.mp3") # Open the file and get the file object - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: file = audio_file result = model.invoke( - model='whisper-1', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - }, + model="whisper-1", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, file=file, - user="abc-123" + user="abc-123", ) assert isinstance(result, str) - assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10' + assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" diff --git a/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py index e2c4c74ee7..f5dd73f2d4 100644 --- a/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py @@ -8,42 +8,27 @@ from core.model_runtime.model_providers.openai.text_embedding.text_embedding imp from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_validate_credentials(setup_openai_mock): model = OpenAITextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='text-embedding-ada-002', - credentials={ - 'openai_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="text-embedding-ada-002", credentials={"openai_api_key": "invalid_key"}) model.validate_credentials( - model='text-embedding-ada-002', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + model="text-embedding-ada-002", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")} ) -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_invoke_model(setup_openai_mock): model = OpenAITextEmbeddingModel() result = model.invoke( - model='text-embedding-ada-002', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY'), - 'openai_api_base': 'https://api.openai.com' - }, - texts=[ - "hello", - "world", - " ".join(["long_text"] * 100), - " ".join(["another_long_text"] * 100) - ], - user="abc-123" + model="text-embedding-ada-002", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY"), "openai_api_base": "https://api.openai.com"}, + texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -55,15 +40,9 @@ def test_get_num_tokens(): model = OpenAITextEmbeddingModel() num_tokens = model.get_num_tokens( - model='text-embedding-ada-002', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY'), - 'openai_api_base': 'https://api.openai.com' - }, - texts=[ - "hello", - "world" - ] + model="text-embedding-ada-002", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY"), "openai_api_base": "https://api.openai.com"}, + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py index c833508569..f2302ef05e 100644 --- a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py @@ -23,21 +23,17 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': 'invalid_key', - 'endpoint_url': 'https://api.together.xyz/v1/', - 'mode': 'chat' - } + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": "invalid_key", "endpoint_url": "https://api.together.xyz/v1/", "mode": "chat"}, ) model.validate_credentials( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'endpoint_url': 'https://api.together.xyz/v1/', - 'mode': 'chat' - } + "api_key": os.environ.get("TOGETHER_API_KEY"), + "endpoint_url": "https://api.together.xyz/v1/", + "mode": "chat", + }, ) @@ -45,28 +41,26 @@ def test_invoke_model(): model = OAIAPICompatLargeLanguageModel() response = model.invoke( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'endpoint_url': 'https://api.together.xyz/v1/', - 'mode': 'completion' + "api_key": os.environ.get("TOGETHER_API_KEY"), + "endpoint_url": "https://api.together.xyz/v1/", + "mode": "completion", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -77,29 +71,27 @@ def test_invoke_stream_model(): model = OAIAPICompatLargeLanguageModel() response = model.invoke( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'endpoint_url': 'https://api.together.xyz/v1/', - 'mode': 'chat', - 'stream_mode_delimiter': '\\n\\n' + "api_key": os.environ.get("TOGETHER_API_KEY"), + "endpoint_url": "https://api.together.xyz/v1/", + "mode": "chat", + "stream_mode_delimiter": "\\n\\n", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -114,28 +106,26 @@ def test_invoke_stream_model_without_delimiter(): model = OAIAPICompatLargeLanguageModel() response = model.invoke( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'endpoint_url': 'https://api.together.xyz/v1/', - 'mode': 'chat' + "api_key": os.environ.get("TOGETHER_API_KEY"), + "endpoint_url": "https://api.together.xyz/v1/", + "mode": "chat", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -151,51 +141,37 @@ def test_invoke_chat_model_with_tools(): model = OAIAPICompatLargeLanguageModel() result = model.invoke( - model='gpt-3.5-turbo', + model="gpt-3.5-turbo", credentials={ - 'api_key': os.environ.get('OPENAI_API_KEY'), - 'endpoint_url': 'https://api.openai.com/v1/', - 'mode': 'chat' + "api_key": os.environ.get("OPENAI_API_KEY"), + "endpoint_url": "https://api.openai.com/v1/", + "mode": "chat", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content="what's the weather today in London?", - ) + ), ], tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "celsius", - "fahrenheit" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 1024 - }, + model_parameters={"temperature": 0.0, "max_tokens": 1024}, stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) @@ -207,19 +183,14 @@ def test_get_num_tokens(): model = OAIAPICompatLargeLanguageModel() num_tokens = model.get_num_tokens( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': os.environ.get('OPENAI_API_KEY'), - 'endpoint_url': 'https://api.openai.com/v1/' - }, + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py index 61079104dc..cf805eafff 100644 --- a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py +++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py @@ -14,18 +14,12 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( model="whisper-1", - credentials={ - "api_key": "invalid_key", - "endpoint_url": "https://api.openai.com/v1/" - }, + credentials={"api_key": "invalid_key", "endpoint_url": "https://api.openai.com/v1/"}, ) model.validate_credentials( model="whisper-1", - credentials={ - "api_key": os.environ.get("OPENAI_API_KEY"), - "endpoint_url": "https://api.openai.com/v1/" - }, + credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"}, ) @@ -47,13 +41,10 @@ def test_invoke_model(): result = model.invoke( model="whisper-1", - credentials={ - "api_key": os.environ.get("OPENAI_API_KEY"), - "endpoint_url": "https://api.openai.com/v1/" - }, + credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"}, file=file, user="abc-123", ) assert isinstance(result, str) - assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10' + assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py index 77d27ec161..052b41605f 100644 --- a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py @@ -12,27 +12,23 @@ from core.model_runtime.model_providers.openai_api_compatible.text_embedding.tex Using OpenAI's API as testing endpoint """ + def test_validate_credentials(): model = OAICompatEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='text-embedding-ada-002', - credentials={ - 'api_key': 'invalid_key', - 'endpoint_url': 'https://api.openai.com/v1/', - 'context_size': 8184 - - } + model="text-embedding-ada-002", + credentials={"api_key": "invalid_key", "endpoint_url": "https://api.openai.com/v1/", "context_size": 8184}, ) model.validate_credentials( - model='text-embedding-ada-002', + model="text-embedding-ada-002", credentials={ - 'api_key': os.environ.get('OPENAI_API_KEY'), - 'endpoint_url': 'https://api.openai.com/v1/', - 'context_size': 8184 - } + "api_key": os.environ.get("OPENAI_API_KEY"), + "endpoint_url": "https://api.openai.com/v1/", + "context_size": 8184, + }, ) @@ -40,19 +36,14 @@ def test_invoke_model(): model = OAICompatEmbeddingModel() result = model.invoke( - model='text-embedding-ada-002', + model="text-embedding-ada-002", credentials={ - 'api_key': os.environ.get('OPENAI_API_KEY'), - 'endpoint_url': 'https://api.openai.com/v1/', - 'context_size': 8184 + "api_key": os.environ.get("OPENAI_API_KEY"), + "endpoint_url": "https://api.openai.com/v1/", + "context_size": 8184, }, - texts=[ - "hello", - "world", - " ".join(["long_text"] * 100), - " ".join(["another_long_text"] * 100) - ], - user="abc-123" + texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -64,16 +55,13 @@ def test_get_num_tokens(): model = OAICompatEmbeddingModel() num_tokens = model.get_num_tokens( - model='text-embedding-ada-002', + model="text-embedding-ada-002", credentials={ - 'api_key': os.environ.get('OPENAI_API_KEY'), - 'endpoint_url': 'https://api.openai.com/v1/embeddings', - 'context_size': 8184 + "api_key": os.environ.get("OPENAI_API_KEY"), + "endpoint_url": "https://api.openai.com/v1/embeddings", + "context_size": 8184, }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) - assert num_tokens == 2 \ No newline at end of file + assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/openllm/test_embedding.py b/api/tests/integration_tests/model_runtime/openllm/test_embedding.py index 9eb05a111d..14d47217af 100644 --- a/api/tests/integration_tests/model_runtime/openllm/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/openllm/test_embedding.py @@ -12,17 +12,17 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': 'ww' + os.environ.get('OPENLLM_SERVER_URL'), - } + "server_url": "ww" + os.environ.get("OPENLLM_SERVER_URL"), + }, ) model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), - } + "server_url": os.environ.get("OPENLLM_SERVER_URL"), + }, ) @@ -30,33 +30,28 @@ def test_invoke_model(): model = OpenLLMTextEmbeddingModel() result = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), + "server_url": os.environ.get("OPENLLM_SERVER_URL"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 assert result.usage.total_tokens > 0 + def test_get_num_tokens(): model = OpenLLMTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), + "server_url": os.environ.get("OPENLLM_SERVER_URL"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/openllm/test_llm.py b/api/tests/integration_tests/model_runtime/openllm/test_llm.py index 853a0fbe3c..35939e3cfe 100644 --- a/api/tests/integration_tests/model_runtime/openllm/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openllm/test_llm.py @@ -14,67 +14,61 @@ def test_validate_credentials_for_chat_model(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': 'invalid_key', - } + "server_url": "invalid_key", + }, ) model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), - } + "server_url": os.environ.get("OPENLLM_SERVER_URL"), + }, ) + def test_invoke_model(): model = OpenLLMLargeLanguageModel() response = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), + "server_url": os.environ.get("OPENLLM_SERVER_URL"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_stream_model(): model = OpenLLMLargeLanguageModel() response = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), + "server_url": os.environ.get("OPENLLM_SERVER_URL"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -84,21 +78,18 @@ def test_invoke_stream_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_get_num_tokens(): model = OpenLLMLargeLanguageModel() response = model.get_num_tokens( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'server_url': os.environ.get('OPENLLM_SERVER_URL'), + "server_url": os.environ.get("OPENLLM_SERVER_URL"), }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - tools=[] + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], ) assert isinstance(response, int) - assert response == 3 \ No newline at end of file + assert response == 3 diff --git a/api/tests/integration_tests/model_runtime/openrouter/test_llm.py b/api/tests/integration_tests/model_runtime/openrouter/test_llm.py index 8f1fb4c4ad..ce4876a73a 100644 --- a/api/tests/integration_tests/model_runtime/openrouter/test_llm.py +++ b/api/tests/integration_tests/model_runtime/openrouter/test_llm.py @@ -19,19 +19,12 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='mistralai/mixtral-8x7b-instruct', - credentials={ - 'api_key': 'invalid_key', - 'mode': 'chat' - } + model="mistralai/mixtral-8x7b-instruct", credentials={"api_key": "invalid_key", "mode": "chat"} ) model.validate_credentials( - model='mistralai/mixtral-8x7b-instruct', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'chat' - } + model="mistralai/mixtral-8x7b-instruct", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"}, ) @@ -39,27 +32,22 @@ def test_invoke_model(): model = OpenRouterLargeLanguageModel() response = model.invoke( - model='mistralai/mixtral-8x7b-instruct', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'completion' - }, + model="mistralai/mixtral-8x7b-instruct", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "completion"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -70,27 +58,22 @@ def test_invoke_stream_model(): model = OpenRouterLargeLanguageModel() response = model.invoke( - model='mistralai/mixtral-8x7b-instruct', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'chat' - }, + model="mistralai/mixtral-8x7b-instruct", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -105,18 +88,16 @@ def test_get_num_tokens(): model = OpenRouterLargeLanguageModel() num_tokens = model.get_num_tokens( - model='mistralai/mixtral-8x7b-instruct', + model="mistralai/mixtral-8x7b-instruct", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), + "api_key": os.environ.get("TOGETHER_API_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/replicate/test_llm.py b/api/tests/integration_tests/model_runtime/replicate/test_llm.py index e248f064c0..b940005b71 100644 --- a/api/tests/integration_tests/model_runtime/replicate/test_llm.py +++ b/api/tests/integration_tests/model_runtime/replicate/test_llm.py @@ -14,19 +14,19 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='meta/llama-2-13b-chat', + model="meta/llama-2-13b-chat", credentials={ - 'replicate_api_token': 'invalid_key', - 'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d' - } + "replicate_api_token": "invalid_key", + "model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d", + }, ) model.validate_credentials( - model='meta/llama-2-13b-chat', + model="meta/llama-2-13b-chat", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d' - } + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d", + }, ) @@ -34,27 +34,25 @@ def test_invoke_model(): model = ReplicateLargeLanguageModel() response = model.invoke( - model='meta/llama-2-13b-chat', + model="meta/llama-2-13b-chat", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -65,27 +63,25 @@ def test_invoke_stream_model(): model = ReplicateLargeLanguageModel() response = model.invoke( - model='mistralai/mixtral-8x7b-instruct-v0.1', + model="mistralai/mixtral-8x7b-instruct-v0.1", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': '2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -100,19 +96,17 @@ def test_get_num_tokens(): model = ReplicateLargeLanguageModel() num_tokens = model.get_num_tokens( - model='', + model="", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': '2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e", }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 14 diff --git a/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py b/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py index 5708ec9e5a..397715f225 100644 --- a/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py @@ -12,19 +12,19 @@ def test_validate_credentials_one(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='replicate/all-mpnet-base-v2', + model="replicate/all-mpnet-base-v2", credentials={ - 'replicate_api_token': 'invalid_key', - 'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305' - } + "replicate_api_token": "invalid_key", + "model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305", + }, ) model.validate_credentials( - model='replicate/all-mpnet-base-v2', + model="replicate/all-mpnet-base-v2", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305' - } + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305", + }, ) @@ -33,19 +33,19 @@ def test_validate_credentials_two(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='nateraw/bge-large-en-v1.5', + model="nateraw/bge-large-en-v1.5", credentials={ - 'replicate_api_token': 'invalid_key', - 'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1' - } + "replicate_api_token": "invalid_key", + "model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1", + }, ) model.validate_credentials( - model='nateraw/bge-large-en-v1.5', + model="nateraw/bge-large-en-v1.5", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1' - } + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1", + }, ) @@ -53,16 +53,13 @@ def test_invoke_model_one(): model = ReplicateEmbeddingModel() result = model.invoke( - model='nateraw/bge-large-en-v1.5', + model="nateraw/bge-large-en-v1.5", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -74,16 +71,13 @@ def test_invoke_model_two(): model = ReplicateEmbeddingModel() result = model.invoke( - model='andreasjansson/clip-features', + model="andreasjansson/clip-features", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': '75b33f253f7714a281ad3e9b28f63e3232d583716ef6718f2e46641077ea040a' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "75b33f253f7714a281ad3e9b28f63e3232d583716ef6718f2e46641077ea040a", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -95,16 +89,13 @@ def test_invoke_model_three(): model = ReplicateEmbeddingModel() result = model.invoke( - model='replicate/all-mpnet-base-v2', + model="replicate/all-mpnet-base-v2", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -116,16 +107,13 @@ def test_invoke_model_four(): model = ReplicateEmbeddingModel() result = model.invoke( - model='nateraw/jina-embeddings-v2-base-en', + model="nateraw/jina-embeddings-v2-base-en", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -137,15 +125,12 @@ def test_get_num_tokens(): model = ReplicateEmbeddingModel() num_tokens = model.get_num_tokens( - model='nateraw/jina-embeddings-v2-base-en', + model="nateraw/jina-embeddings-v2-base-en", credentials={ - 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'), - 'model_version': 'f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e' + "replicate_api_token": os.environ.get("REPLICATE_API_KEY"), + "model_version": "f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e", }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py b/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py index 639227e745..9f0b439d6c 100644 --- a/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py +++ b/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py @@ -10,10 +10,6 @@ def test_validate_provider_credentials(): provider = SageMakerProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) diff --git a/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py b/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py index c67849dd79..d5a6798a1e 100644 --- a/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py @@ -12,11 +12,11 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='bge-m3-rerank-v2', + model="bge-m3-rerank-v2", credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), }, query="What is the capital of the United States?", docs=[ @@ -25,7 +25,7 @@ def test_validate_credentials(): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) @@ -33,11 +33,11 @@ def test_invoke_model(): model = SageMakerRerankModel() result = model.invoke( - model='bge-m3-rerank-v2', + model="bge-m3-rerank-v2", credentials={ "aws_region": os.getenv("AWS_REGION"), "aws_access_key": os.getenv("AWS_ACCESS_KEY"), - "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), }, query="What is the capital of the United States?", docs=[ @@ -46,7 +46,7 @@ def test_invoke_model(): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) assert isinstance(result, RerankResult) diff --git a/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py b/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py index e817e8f04a..e4e404c7a8 100644 --- a/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py @@ -11,45 +11,23 @@ def test_validate_credentials(): model = SageMakerEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='bge-m3', - credentials={ - } - ) + model.validate_credentials(model="bge-m3", credentials={}) - model.validate_credentials( - model='bge-m3-embedding', - credentials={ - } - ) + model.validate_credentials(model="bge-m3-embedding", credentials={}) def test_invoke_model(): model = SageMakerEmbeddingModel() - result = model.invoke( - model='bge-m3-embedding', - credentials={ - }, - texts=[ - "hello", - "world" - ], - user="abc-123" - ) + result = model.invoke(model="bge-m3-embedding", credentials={}, texts=["hello", "world"], user="abc-123") assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 + def test_get_num_tokens(): model = SageMakerEmbeddingModel() - num_tokens = model.get_num_tokens( - model='bge-m3-embedding', - credentials={ - }, - texts=[ - ] - ) + num_tokens = model.get_num_tokens(model="bge-m3-embedding", credentials={}, texts=[]) assert num_tokens == 0 diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_llm.py b/api/tests/integration_tests/model_runtime/siliconflow/test_llm.py index befdd82352..f47c9c5588 100644 --- a/api/tests/integration_tests/model_runtime/siliconflow/test_llm.py +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_llm.py @@ -13,41 +13,22 @@ def test_validate_credentials(): model = SiliconflowLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials={"api_key": "invalid_key"}) - model.validate_credentials( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials={ - 'api_key': os.environ.get('API_KEY') - } - ) + model.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials={"api_key": os.environ.get("API_KEY")}) def test_invoke_model(): model = SiliconflowLargeLanguageModel() response = model.invoke( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials={ - 'api_key': os.environ.get('API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 10 - }, - stop=['How'], + model="deepseek-ai/DeepSeek-V2-Chat", + credentials={"api_key": os.environ.get("API_KEY")}, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -58,22 +39,12 @@ def test_invoke_stream_model(): model = SiliconflowLargeLanguageModel() response = model.invoke( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials={ - 'api_key': os.environ.get('API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 100, - 'seed': 1234 - }, + model="deepseek-ai/DeepSeek-V2-Chat", + credentials={"api_key": os.environ.get("API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -89,18 +60,14 @@ def test_get_num_tokens(): model = SiliconflowLargeLanguageModel() num_tokens = model.get_num_tokens( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials={ - 'api_key': os.environ.get('API_KEY') - }, + model="deepseek-ai/DeepSeek-V2-Chat", + credentials={"api_key": os.environ.get("API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 12 diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_provider.py b/api/tests/integration_tests/model_runtime/siliconflow/test_provider.py index 7b9211a5db..8f70210b7a 100644 --- a/api/tests/integration_tests/model_runtime/siliconflow/test_provider.py +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_provider.py @@ -10,12 +10,6 @@ def test_validate_provider_credentials(): provider = SiliconflowProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py b/api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py index 7b3ff82727..ad794613f9 100644 --- a/api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py @@ -13,9 +13,7 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( model="BAAI/bge-reranker-v2-m3", - credentials={ - "api_key": "invalid_key" - }, + credentials={"api_key": "invalid_key"}, ) model.validate_credentials( @@ -30,17 +28,17 @@ def test_invoke_model(): model = SiliconflowRerankModel() result = model.invoke( - model='BAAI/bge-reranker-v2-m3', + model="BAAI/bge-reranker-v2-m3", credentials={ "api_key": os.environ.get("API_KEY"), }, query="Who is Kasumi?", docs=[ - "Kasumi is a girl's name of Japanese origin meaning \"mist\".", + 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", - "and she leads a team named PopiParty." + "and she leads a team named PopiParty.", ], - score_threshold=0.8 + score_threshold=0.8, ) assert isinstance(result, RerankResult) diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py b/api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py index 82b7921c85..0502ba5ab4 100644 --- a/api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py @@ -12,16 +12,12 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( model="iic/SenseVoiceSmall", - credentials={ - "api_key": "invalid_key" - }, + credentials={"api_key": "invalid_key"}, ) model.validate_credentials( model="iic/SenseVoiceSmall", - credentials={ - "api_key": os.environ.get("API_KEY") - }, + credentials={"api_key": os.environ.get("API_KEY")}, ) @@ -42,12 +38,8 @@ def test_invoke_model(): file = audio_file result = model.invoke( - model="iic/SenseVoiceSmall", - credentials={ - "api_key": os.environ.get("API_KEY") - }, - file=file + model="iic/SenseVoiceSmall", credentials={"api_key": os.environ.get("API_KEY")}, file=file ) assert isinstance(result, str) - assert result == '1,2,3,4,5,6,7,8,9,10.' + assert result == "1,2,3,4,5,6,7,8,9,10." diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_text_embedding.py b/api/tests/integration_tests/model_runtime/siliconflow/test_text_embedding.py index 18bd2e893a..ab143c1061 100644 --- a/api/tests/integration_tests/model_runtime/siliconflow/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_text_embedding.py @@ -15,9 +15,7 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( model="BAAI/bge-large-zh-v1.5", - credentials={ - "api_key": "invalid_key" - }, + credentials={"api_key": "invalid_key"}, ) model.validate_credentials( diff --git a/api/tests/integration_tests/model_runtime/spark/test_llm.py b/api/tests/integration_tests/model_runtime/spark/test_llm.py index 706316449d..4fe2fd8c0a 100644 --- a/api/tests/integration_tests/model_runtime/spark/test_llm.py +++ b/api/tests/integration_tests/model_runtime/spark/test_llm.py @@ -13,20 +13,15 @@ def test_validate_credentials(): model = SparkLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='spark-1.5', - credentials={ - 'app_id': 'invalid_key' - } - ) + model.validate_credentials(model="spark-1.5", credentials={"app_id": "invalid_key"}) model.validate_credentials( - model='spark-1.5', + model="spark-1.5", credentials={ - 'app_id': os.environ.get('SPARK_APP_ID'), - 'api_secret': os.environ.get('SPARK_API_SECRET'), - 'api_key': os.environ.get('SPARK_API_KEY') - } + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), + }, ) @@ -34,24 +29,17 @@ def test_invoke_model(): model = SparkLargeLanguageModel() response = model.invoke( - model='spark-1.5', + model="spark-1.5", credentials={ - 'app_id': os.environ.get('SPARK_APP_ID'), - 'api_secret': os.environ.get('SPARK_API_SECRET'), - 'api_key': os.environ.get('SPARK_API_KEY') + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 10 - }, - stop=['How'], + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -62,23 +50,16 @@ def test_invoke_stream_model(): model = SparkLargeLanguageModel() response = model.invoke( - model='spark-1.5', + model="spark-1.5", credentials={ - 'app_id': os.environ.get('SPARK_APP_ID'), - 'api_secret': os.environ.get('SPARK_API_SECRET'), - 'api_key': os.environ.get('SPARK_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 100 + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.5, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -94,20 +75,18 @@ def test_get_num_tokens(): model = SparkLargeLanguageModel() num_tokens = model.get_num_tokens( - model='spark-1.5', + model="spark-1.5", credentials={ - 'app_id': os.environ.get('SPARK_APP_ID'), - 'api_secret': os.environ.get('SPARK_API_SECRET'), - 'api_key': os.environ.get('SPARK_API_KEY') + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 14 diff --git a/api/tests/integration_tests/model_runtime/spark/test_provider.py b/api/tests/integration_tests/model_runtime/spark/test_provider.py index 8e22815a86..9da0df6bb3 100644 --- a/api/tests/integration_tests/model_runtime/spark/test_provider.py +++ b/api/tests/integration_tests/model_runtime/spark/test_provider.py @@ -10,14 +10,12 @@ def test_validate_provider_credentials(): provider = SparkProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) provider.validate_provider_credentials( credentials={ - 'app_id': os.environ.get('SPARK_APP_ID'), - 'api_secret': os.environ.get('SPARK_API_SECRET'), - 'api_key': os.environ.get('SPARK_API_KEY') + "app_id": os.environ.get("SPARK_APP_ID"), + "api_secret": os.environ.get("SPARK_API_SECRET"), + "api_key": os.environ.get("SPARK_API_KEY"), } ) diff --git a/api/tests/integration_tests/model_runtime/stepfun/test_llm.py b/api/tests/integration_tests/model_runtime/stepfun/test_llm.py index d703147d63..c03b1bae1f 100644 --- a/api/tests/integration_tests/model_runtime/stepfun/test_llm.py +++ b/api/tests/integration_tests/model_runtime/stepfun/test_llm.py @@ -21,40 +21,22 @@ def test_validate_credentials(): model = StepfunLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='step-1-8k', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="step-1-8k", credentials={"api_key": "invalid_key"}) + + model.validate_credentials(model="step-1-8k", credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}) - model.validate_credentials( - model='step-1-8k', - credentials={ - 'api_key': os.environ.get('STEPFUN_API_KEY') - } - ) def test_invoke_model(): model = StepfunLargeLanguageModel() response = model.invoke( - model='step-1-8k', - credentials={ - 'api_key': os.environ.get('STEPFUN_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.9, - 'top_p': 0.7 - }, - stop=['Hi'], + model="step-1-8k", + credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.9, "top_p": 0.7}, + stop=["Hi"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -65,24 +47,17 @@ def test_invoke_stream_model(): model = StepfunLargeLanguageModel() response = model.invoke( - model='step-1-8k', - credentials={ - 'api_key': os.environ.get('STEPFUN_API_KEY') - }, + model="step-1-8k", + credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.9, - 'top_p': 0.7 - }, + model_parameters={"temperature": 0.9, "top_p": 0.7}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -98,10 +73,7 @@ def test_get_customizable_model_schema(): model = StepfunLargeLanguageModel() schema = model.get_customizable_model_schema( - model='step-1-8k', - credentials={ - 'api_key': os.environ.get('STEPFUN_API_KEY') - } + model="step-1-8k", credentials={"api_key": os.environ.get("STEPFUN_API_KEY")} ) assert isinstance(schema, AIModelEntity) @@ -110,67 +82,44 @@ def test_invoke_chat_model_with_tools(): model = StepfunLargeLanguageModel() result = model.invoke( - model='step-1-8k', - credentials={ - 'api_key': os.environ.get('STEPFUN_API_KEY') - }, + model="step-1-8k", + credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content="what's the weather today in Shanghai?", - ) + ), ], - model_parameters={ - 'temperature': 0.9, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.9, "max_tokens": 100}, tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), PromptMessageTool( - name='get_stock_price', - description='Get the current stock price', + name="get_stock_price", + description="Get the current stock price", parameters={ "type": "object", - "properties": { - "symbol": { - "type": "string", - "description": "The stock symbol" - } - }, - "required": [ - "symbol" - ] - } - ) + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), ], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert isinstance(result.message, AssistantPromptMessage) - assert len(result.message.tool_calls) > 0 \ No newline at end of file + assert len(result.message.tool_calls) > 0 diff --git a/api/tests/integration_tests/model_runtime/test_model_provider_factory.py b/api/tests/integration_tests/model_runtime/test_model_provider_factory.py index fd8aa3f610..0ec4b0b724 100644 --- a/api/tests/integration_tests/model_runtime/test_model_provider_factory.py +++ b/api/tests/integration_tests/model_runtime/test_model_provider_factory.py @@ -24,13 +24,8 @@ def test_get_models(): providers = factory.get_models( model_type=ModelType.LLM, provider_configs=[ - ProviderConfig( - provider='openai', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - ) - ] + ProviderConfig(provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) + ], ) logger.debug(providers) @@ -44,29 +39,21 @@ def test_get_models(): assert provider_model.model_type == ModelType.LLM providers = factory.get_models( - provider='openai', + provider="openai", provider_configs=[ - ProviderConfig( - provider='openai', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } - ) - ] + ProviderConfig(provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}) + ], ) assert len(providers) == 1 assert isinstance(providers[0], SimpleProviderEntity) - assert providers[0].provider == 'openai' + assert providers[0].provider == "openai" def test_provider_credentials_validate(): factory = ModelProviderFactory() factory.provider_credentials_validate( - provider='openai', - credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")} ) @@ -79,4 +66,4 @@ def test__get_model_provider_map(): logger.debug(model_provider.provider_instance) assert len(model_providers) >= 1 - assert isinstance(model_providers['openai'], ModelProviderExtension) + assert isinstance(model_providers["openai"], ModelProviderExtension) diff --git a/api/tests/integration_tests/model_runtime/togetherai/test_llm.py b/api/tests/integration_tests/model_runtime/togetherai/test_llm.py index 698f534517..06ebc2a82d 100644 --- a/api/tests/integration_tests/model_runtime/togetherai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/togetherai/test_llm.py @@ -19,76 +19,61 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': 'invalid_key', - 'mode': 'chat' - } + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={"api_key": "invalid_key", "mode": "chat"} ) model.validate_credentials( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'chat' - } + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"}, ) + def test_invoke_model(): model = TogetherAILargeLanguageModel() response = model.invoke( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'completion' - }, + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "completion"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 + def test_invoke_stream_model(): model = TogetherAILargeLanguageModel() response = model.invoke( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', - credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), - 'mode': 'chat' - }, + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Who are you?' - ) + UserPromptMessage(content="Who are you?"), ], model_parameters={ - 'temperature': 1.0, - 'top_k': 2, - 'top_p': 0.5, + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, }, - stop=['How'], + stop=["How"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -98,22 +83,21 @@ def test_invoke_stream_model(): assert isinstance(chunk.delta, LLMResultChunkDelta) assert isinstance(chunk.delta.message, AssistantPromptMessage) + def test_get_num_tokens(): model = TogetherAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='mistralai/Mixtral-8x7B-Instruct-v0.1', + model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={ - 'api_key': os.environ.get('TOGETHER_API_KEY'), + "api_key": os.environ.get("TOGETHER_API_KEY"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert isinstance(num_tokens, int) diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_llm.py b/api/tests/integration_tests/model_runtime/tongyi/test_llm.py index 81fb676018..61650735f2 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_llm.py +++ b/api/tests/integration_tests/model_runtime/tongyi/test_llm.py @@ -13,18 +13,10 @@ def test_validate_credentials(): model = TongyiLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='qwen-turbo', - credentials={ - 'dashscope_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="qwen-turbo", credentials={"dashscope_api_key": "invalid_key"}) model.validate_credentials( - model='qwen-turbo', - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - } + model="qwen-turbo", credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")} ) @@ -32,22 +24,13 @@ def test_invoke_model(): model = TongyiLargeLanguageModel() response = model.invoke( - model='qwen-turbo', - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 10 - }, - stop=['How'], + model="qwen-turbo", + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -58,22 +41,12 @@ def test_invoke_stream_model(): model = TongyiLargeLanguageModel() response = model.invoke( - model='qwen-turbo', - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 100, - 'seed': 1234 - }, + model="qwen-turbo", + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -89,18 +62,14 @@ def test_get_num_tokens(): model = TongyiLargeLanguageModel() num_tokens = model.get_num_tokens( - model='qwen-turbo', - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - }, + model="qwen-turbo", + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 12 diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_provider.py b/api/tests/integration_tests/model_runtime/tongyi/test_provider.py index 6145c1dc37..0bc96c84e7 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_provider.py +++ b/api/tests/integration_tests/model_runtime/tongyi/test_provider.py @@ -10,12 +10,8 @@ def test_validate_provider_credentials(): provider = TongyiProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) provider.validate_provider_credentials( - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - } + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")} ) diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_response_format.py b/api/tests/integration_tests/model_runtime/tongyi/test_response_format.py index 1b0a38d5d1..905e7907fd 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_response_format.py +++ b/api/tests/integration_tests/model_runtime/tongyi/test_response_format.py @@ -39,21 +39,17 @@ def invoke_model_with_json_response(model_name="qwen-max-0403"): response = model.invoke( model=model_name, - credentials={ - 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') - }, + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}, prompt_messages=[ - UserPromptMessage( - content='output json data with format `{"data": "test", "code": 200, "msg": "success"}' - ) + UserPromptMessage(content='output json data with format `{"data": "test", "code": 200, "msg": "success"}') ], model_parameters={ - 'temperature': 0.5, - 'max_tokens': 50, - 'response_format': 'JSON', + "temperature": 0.5, + "max_tokens": 50, + "response_format": "JSON", }, stream=True, - user="abc-123" + user="abc-123", ) print("=====================================") print(response) @@ -81,4 +77,4 @@ def is_json(s): json.loads(s) except ValueError: return False - return True \ No newline at end of file + return True diff --git a/api/tests/integration_tests/model_runtime/upstage/test_llm.py b/api/tests/integration_tests/model_runtime/upstage/test_llm.py index c35580a8b1..bc7517acbe 100644 --- a/api/tests/integration_tests/model_runtime/upstage/test_llm.py +++ b/api/tests/integration_tests/model_runtime/upstage/test_llm.py @@ -26,151 +26,113 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_credentials_for_chat_model(setup_openai_mock): model = UpstageLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): # model name to gpt-3.5-turbo because of mocking - model.validate_credentials( - model='gpt-3.5-turbo', - credentials={ - 'upstage_api_key': 'invalid_key' - } - ) + model.validate_credentials(model="gpt-3.5-turbo", credentials={"upstage_api_key": "invalid_key"}) model.validate_credentials( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - } + model="solar-1-mini-chat", credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")} ) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model(setup_openai_mock): model = UpstageLargeLanguageModel() result = model.invoke( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - }, + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.0, - 'top_p': 1.0, - 'presence_penalty': 0.0, - 'frequency_penalty': 0.0, - 'max_tokens': 10 + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, }, - stop=['How'], + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert len(result.message.content) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_chat_model_with_tools(setup_openai_mock): model = UpstageLargeLanguageModel() result = model.invoke( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - }, + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), UserPromptMessage( content="what's the weather today in London?", - ) + ), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), PromptMessageTool( - name='get_stock_price', - description='Get the current stock price', + name="get_stock_price", + description="Get the current stock price", parameters={ "type": "object", - "properties": { - "symbol": { - "type": "string", - "description": "The stock symbol" - } - }, - "required": [ - "symbol" - ] - } - ) + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), ], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(result, LLMResult) assert isinstance(result.message, AssistantPromptMessage) assert len(result.message.tool_calls) > 0 -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_invoke_stream_chat_model(setup_openai_mock): model = UpstageLargeLanguageModel() result = model.invoke( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - }, + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], - model_parameters={ - 'temperature': 0.0, - 'max_tokens': 100 - }, + model_parameters={"temperature": 0.0, "max_tokens": 100}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(result, Generator) @@ -189,57 +151,36 @@ def test_get_num_tokens(): model = UpstageLargeLanguageModel() num_tokens = model.get_num_tokens( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ] + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], ) assert num_tokens == 13 num_tokens = model.get_num_tokens( - model='solar-1-mini-chat', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - }, + model="solar-1-mini-chat", + credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], tools=[ PromptMessageTool( - name='get_weather', - description='Determine weather in my location', + name="get_weather", + description="Determine weather in my location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ), - ] + ], ) assert num_tokens == 106 diff --git a/api/tests/integration_tests/model_runtime/upstage/test_provider.py b/api/tests/integration_tests/model_runtime/upstage/test_provider.py index c33eef49b2..9d83779aa0 100644 --- a/api/tests/integration_tests/model_runtime/upstage/test_provider.py +++ b/api/tests/integration_tests/model_runtime/upstage/test_provider.py @@ -7,17 +7,11 @@ from core.model_runtime.model_providers.upstage.upstage import UpstageProvider from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_validate_provider_credentials(setup_openai_mock): provider = UpstageProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/upstage/test_text_embedding.py b/api/tests/integration_tests/model_runtime/upstage/test_text_embedding.py index 54135a0e74..8c83172fa3 100644 --- a/api/tests/integration_tests/model_runtime/upstage/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/upstage/test_text_embedding.py @@ -8,41 +8,31 @@ from core.model_runtime.model_providers.upstage.text_embedding.text_embedding im from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_validate_credentials(setup_openai_mock): model = UpstageTextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='solar-embedding-1-large-passage', - credentials={ - 'upstage_api_key': 'invalid_key' - } + model="solar-embedding-1-large-passage", credentials={"upstage_api_key": "invalid_key"} ) model.validate_credentials( - model='solar-embedding-1-large-passage', - credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY') - } + model="solar-embedding-1-large-passage", credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")} ) -@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True) def test_invoke_model(setup_openai_mock): model = UpstageTextEmbeddingModel() result = model.invoke( - model='solar-embedding-1-large-passage', + model="solar-embedding-1-large-passage", credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY'), + "upstage_api_key": os.environ.get("UPSTAGE_API_KEY"), }, - texts=[ - "hello", - "world", - " ".join(["long_text"] * 100), - " ".join(["another_long_text"] * 100) - ], - user="abc-123" + texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -54,14 +44,11 @@ def test_get_num_tokens(): model = UpstageTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='solar-embedding-1-large-passage', + model="solar-embedding-1-large-passage", credentials={ - 'upstage_api_key': os.environ.get('UPSTAGE_API_KEY'), + "upstage_api_key": os.environ.get("UPSTAGE_API_KEY"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 5 diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py b/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py index 3b399d604e..f831c063a4 100644 --- a/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py @@ -14,26 +14,26 @@ def test_validate_credentials(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': 'INVALID', - 'volc_secret_access_key': 'INVALID', - 'endpoint_id': 'INVALID', - 'base_model_name': 'Doubao-embedding', - } + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": "INVALID", + "volc_secret_access_key": "INVALID", + "endpoint_id": "INVALID", + "base_model_name": "Doubao-embedding", + }, ) model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'), - 'base_model_name': 'Doubao-embedding', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"), + "base_model_name": "Doubao-embedding", }, ) @@ -42,20 +42,17 @@ def test_invoke_model(): model = VolcengineMaaSTextEmbeddingModel() result = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'), - 'base_model_name': 'Doubao-embedding', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"), + "base_model_name": "Doubao-embedding", }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -67,19 +64,16 @@ def test_get_num_tokens(): model = VolcengineMaaSTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'), - 'base_model_name': 'Doubao-embedding', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"), + "base_model_name": "Doubao-embedding", }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py b/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py index 63835d0263..8ff9c41404 100644 --- a/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py +++ b/api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py @@ -14,25 +14,25 @@ def test_validate_credentials_for_chat_model(): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': 'INVALID', - 'volc_secret_access_key': 'INVALID', - 'endpoint_id': 'INVALID', - } + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": "INVALID", + "volc_secret_access_key": "INVALID", + "endpoint_id": "INVALID", + }, ) model.validate_credentials( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), - } + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"), + }, ) @@ -40,28 +40,24 @@ def test_invoke_model(): model = VolcengineMaaSLargeLanguageModel() response = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), - 'base_model_name': 'Skylark2-pro-4k', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"), + "base_model_name": "Skylark2-pro-4k", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) @@ -73,28 +69,24 @@ def test_invoke_stream_model(): model = VolcengineMaaSLargeLanguageModel() response = model.invoke( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), - 'base_model_name': 'Skylark2-pro-4k', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"), + "base_model_name": "Skylark2-pro-4k", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -102,29 +94,24 @@ def test_invoke_stream_model(): assert isinstance(chunk, LLMResultChunk) assert isinstance(chunk.delta, LLMResultChunkDelta) assert isinstance(chunk.delta.message, AssistantPromptMessage) - assert len( - chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True def test_get_num_tokens(): model = VolcengineMaaSLargeLanguageModel() response = model.get_num_tokens( - model='NOT IMPORTANT', + model="NOT IMPORTANT", credentials={ - 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com', - 'volc_region': 'cn-beijing', - 'volc_access_key_id': os.environ.get('VOLC_API_KEY'), - 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'), - 'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'), - 'base_model_name': 'Skylark2-pro-4k', + "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com", + "volc_region": "cn-beijing", + "volc_access_key_id": os.environ.get("VOLC_API_KEY"), + "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"), + "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"), + "base_model_name": "Skylark2-pro-4k", }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - tools=[] + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], ) assert isinstance(response, int) diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py b/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py index d886226cf9..ac38340aec 100644 --- a/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py +++ b/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py @@ -10,13 +10,10 @@ def test_invoke_embedding_v1(): model = WenxinTextEmbeddingModel() response = model.invoke( - model='embedding-v1', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - texts=['hello', '你好', 'xxxxx'], - user="abc-123" + model="embedding-v1", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + texts=["hello", "你好", "xxxxx"], + user="abc-123", ) assert isinstance(response, TextEmbeddingResult) @@ -29,13 +26,10 @@ def test_invoke_embedding_bge_large_en(): model = WenxinTextEmbeddingModel() response = model.invoke( - model='bge-large-en', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - texts=['hello', '你好', 'xxxxx'], - user="abc-123" + model="bge-large-en", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + texts=["hello", "你好", "xxxxx"], + user="abc-123", ) assert isinstance(response, TextEmbeddingResult) @@ -48,13 +42,10 @@ def test_invoke_embedding_bge_large_zh(): model = WenxinTextEmbeddingModel() response = model.invoke( - model='bge-large-zh', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - texts=['hello', '你好', 'xxxxx'], - user="abc-123" + model="bge-large-zh", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + texts=["hello", "你好", "xxxxx"], + user="abc-123", ) assert isinstance(response, TextEmbeddingResult) @@ -67,13 +58,10 @@ def test_invoke_embedding_tao_8k(): model = WenxinTextEmbeddingModel() response = model.invoke( - model='tao-8k', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - texts=['hello', '你好', 'xxxxx'], - user="abc-123" + model="tao-8k", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + texts=["hello", "你好", "xxxxx"], + user="abc-123", ) assert isinstance(response, TextEmbeddingResult) diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_llm.py b/api/tests/integration_tests/model_runtime/wenxin/test_llm.py index 164e8253d9..e2e58f15e0 100644 --- a/api/tests/integration_tests/model_runtime/wenxin/test_llm.py +++ b/api/tests/integration_tests/model_runtime/wenxin/test_llm.py @@ -17,161 +17,125 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) + def test_validate_credentials_for_chat_model(): sleep(3) model = ErnieBotLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='ernie-bot', - credentials={ - 'api_key': 'invalid_key', - 'secret_key': 'invalid_key' - } + model="ernie-bot", credentials={"api_key": "invalid_key", "secret_key": "invalid_key"} ) model.validate_credentials( - model='ernie-bot', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - } + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, ) + def test_invoke_model_ernie_bot(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_model_ernie_bot_turbo(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot-turbo', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + model="ernie-bot-turbo", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_model_ernie_8k(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot-8k', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + model="ernie-bot-8k", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_model_ernie_bot_4(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot-4', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + model="ernie-bot-4", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_stream_model(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-3.5-8k', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], + model="ernie-3.5-8k", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -181,63 +145,48 @@ def test_invoke_stream_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_invoke_model_with_system(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - SystemPromptMessage( - content='你是Kasumi' - ), - UserPromptMessage( - content='你是谁?' - ) - ], + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[SystemPromptMessage(content="你是Kasumi"), UserPromptMessage(content="你是谁?")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) - assert 'kasumi' in response.message.content.lower() + assert "kasumi" in response.message.content.lower() + def test_invoke_with_search(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.invoke( - model='ernie-bot', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='北京今天的天气怎么样' - ) - ], + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'disable_search': True, + "temperature": 0.7, + "top_p": 1.0, + "disable_search": True, }, stop=[], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) - total_message = '' + total_message = "" for chunk in response: assert isinstance(chunk, LLMResultChunk) assert isinstance(chunk.delta, LLMResultChunkDelta) @@ -247,25 +196,19 @@ def test_invoke_with_search(): assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True # there should be 对不起、我不能、不支持…… - assert ('不' in total_message or '抱歉' in total_message or '无法' in total_message) + assert "不" in total_message or "抱歉" in total_message or "无法" in total_message + def test_get_num_tokens(): sleep(3) model = ErnieBotLargeLanguageModel() response = model.get_num_tokens( - model='ernie-bot', - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - tools=[] + model="ernie-bot", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + tools=[], ) assert isinstance(response, int) - assert response == 10 \ No newline at end of file + assert response == 10 diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_provider.py b/api/tests/integration_tests/model_runtime/wenxin/test_provider.py index 8922aa1868..337c3d2a80 100644 --- a/api/tests/integration_tests/model_runtime/wenxin/test_provider.py +++ b/api/tests/integration_tests/model_runtime/wenxin/test_provider.py @@ -10,16 +10,8 @@ def test_validate_provider_credentials(): provider = WenxinProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={ - 'api_key': 'hahahaha', - 'secret_key': 'hahahaha' - } - ) + provider.validate_provider_credentials(credentials={"api_key": "hahahaha", "secret_key": "hahahaha"}) provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('WENXIN_API_KEY'), - 'secret_key': os.environ.get('WENXIN_SECRET_KEY') - } + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")} ) diff --git a/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py b/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py index f0a5151f3d..8e778d005a 100644 --- a/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py +++ b/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py @@ -8,61 +8,57 @@ from core.model_runtime.model_providers.xinference.text_embedding.text_embedding from tests.integration_tests.model_runtime.__mock.xinference import MOCK, setup_xinference_mock -@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True) def test_validate_credentials(setup_xinference_mock): model = XinferenceTextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='bge-base-en', + model="bge-base-en", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': 'www ' + os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": "www " + os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"), + }, ) model.validate_credentials( - model='bge-base-en', + model="bge-base-en", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"), + }, ) -@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True) def test_invoke_model(setup_xinference_mock): model = XinferenceTextEmbeddingModel() result = model.invoke( - model='bge-base-en', + model="bge-base-en", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"), }, - texts=[ - "hello", - "world" - ], - user="abc-123" + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 assert result.usage.total_tokens > 0 + def test_get_num_tokens(): model = XinferenceTextEmbeddingModel() num_tokens = model.get_num_tokens( - model='bge-base-en', + model="bge-base-en", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"), }, - texts=[ - "hello", - "world" - ] + texts=["hello", "world"], ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/xinference/test_llm.py b/api/tests/integration_tests/model_runtime/xinference/test_llm.py index 47730406de..48d1ae323d 100644 --- a/api/tests/integration_tests/model_runtime/xinference/test_llm.py +++ b/api/tests/integration_tests/model_runtime/xinference/test_llm.py @@ -20,92 +20,84 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_moc from tests.integration_tests.model_runtime.__mock.xinference import setup_xinference_mock -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True) def test_validate_credentials_for_chat_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': 'www ' + os.environ.get('XINFERENCE_CHAT_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": "www " + os.environ.get("XINFERENCE_CHAT_MODEL_UID"), + }, ) with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='aaaaa', - credentials={ - 'server_url': '', - 'model_uid': '' - } - ) + model.validate_credentials(model="aaaaa", credentials={"server_url": "", "model_uid": ""}) model.validate_credentials( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"), + }, ) -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True) def test_invoke_chat_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() response = model.invoke( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True) def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() response = model.invoke( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -114,6 +106,8 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock): assert isinstance(chunk.delta, LLMResultChunkDelta) assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + """ Funtion calling of xinference does not support stream mode currently """ @@ -168,7 +162,7 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock): # ) # assert isinstance(response, Generator) - + # call: LLMResultChunk = None # chunks = [] @@ -241,86 +235,75 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock): # assert response.usage.total_tokens > 0 # assert response.message.tool_calls[0].function.name == 'get_current_weather' -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True) def test_validate_credentials_for_generation_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='alapaca', + model="alapaca", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': 'www ' + os.environ.get('XINFERENCE_GENERATION_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": "www " + os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), + }, ) with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='alapaca', - credentials={ - 'server_url': '', - 'model_uid': '' - } - ) + model.validate_credentials(model="alapaca", credentials={"server_url": "", "model_uid": ""}) model.validate_credentials( - model='alapaca', + model="alapaca", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), + }, ) -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True) def test_invoke_generation_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() response = model.invoke( - model='alapaca', + model="alapaca", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), }, - prompt_messages=[ - UserPromptMessage( - content='the United States is' - ) - ], + prompt_messages=[UserPromptMessage(content="the United States is")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], user="abc-123", - stream=False + stream=False, ) assert isinstance(response, LLMResult) assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 -@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True) def test_invoke_stream_generation_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() response = model.invoke( - model='alapaca', + model="alapaca", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), }, - prompt_messages=[ - UserPromptMessage( - content='the United States is' - ) - ], + prompt_messages=[UserPromptMessage(content="the United States is")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, + "temperature": 0.7, + "top_p": 1.0, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -330,68 +313,54 @@ def test_invoke_stream_generation_model(setup_openai_mock, setup_xinference_mock assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_get_num_tokens(): model = XinferenceAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) - ] + ], ) assert isinstance(num_tokens, int) assert num_tokens == 77 num_tokens = model.get_num_tokens( - model='ChatGLM3', + model="ChatGLM3", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"), }, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) + UserPromptMessage(content="Hello World!"), ], ) assert isinstance(num_tokens, int) - assert num_tokens == 21 \ No newline at end of file + assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/xinference/test_rerank.py b/api/tests/integration_tests/model_runtime/xinference/test_rerank.py index 9012c16a7e..71ac4eef7c 100644 --- a/api/tests/integration_tests/model_runtime/xinference/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/xinference/test_rerank.py @@ -8,44 +8,42 @@ from core.model_runtime.model_providers.xinference.rerank.rerank import Xinferen from tests.integration_tests.model_runtime.__mock.xinference import MOCK, setup_xinference_mock -@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True) def test_validate_credentials(setup_xinference_mock): model = XinferenceRerankModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='bge-reranker-base', - credentials={ - 'server_url': 'awdawdaw', - 'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID') - } + model="bge-reranker-base", + credentials={"server_url": "awdawdaw", "model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID")}, ) model.validate_credentials( - model='bge-reranker-base', + model="bge-reranker-base", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID') - } + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID"), + }, ) -@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True) def test_invoke_model(setup_xinference_mock): model = XinferenceRerankModel() result = model.invoke( - model='bge-reranker-base', + model="bge-reranker-base", credentials={ - 'server_url': os.environ.get('XINFERENCE_SERVER_URL'), - 'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID') + "server_url": os.environ.get("XINFERENCE_SERVER_URL"), + "model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID"), }, query="Who is Kasumi?", docs=[ - "Kasumi is a girl's name of Japanese origin meaning \"mist\".", + 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", - "and she leads a team named PopiParty." + "and she leads a team named PopiParty.", ], - score_threshold=0.8 + score_threshold=0.8, ) assert isinstance(result, RerankResult) diff --git a/api/tests/integration_tests/model_runtime/zhinao/test_llm.py b/api/tests/integration_tests/model_runtime/zhinao/test_llm.py index 47a5b6cae2..4ca1b86476 100644 --- a/api/tests/integration_tests/model_runtime/zhinao/test_llm.py +++ b/api/tests/integration_tests/model_runtime/zhinao/test_llm.py @@ -13,41 +13,22 @@ def test_validate_credentials(): model = ZhinaoLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='360gpt2-pro', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="360gpt2-pro", credentials={"api_key": "invalid_key"}) - model.validate_credentials( - model='360gpt2-pro', - credentials={ - 'api_key': os.environ.get('ZHINAO_API_KEY') - } - ) + model.validate_credentials(model="360gpt2-pro", credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}) def test_invoke_model(): model = ZhinaoLargeLanguageModel() response = model.invoke( - model='360gpt2-pro', - credentials={ - 'api_key': os.environ.get('ZHINAO_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 10 - }, - stop=['How'], + model="360gpt2-pro", + credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.5, "max_tokens": 10}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -58,22 +39,12 @@ def test_invoke_stream_model(): model = ZhinaoLargeLanguageModel() response = model.invoke( - model='360gpt2-pro', - credentials={ - 'api_key': os.environ.get('ZHINAO_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.5, - 'max_tokens': 100, - 'seed': 1234 - }, + model="360gpt2-pro", + credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -89,18 +60,14 @@ def test_get_num_tokens(): model = ZhinaoLargeLanguageModel() num_tokens = model.get_num_tokens( - model='360gpt2-pro', - credentials={ - 'api_key': os.environ.get('ZHINAO_API_KEY') - }, + model="360gpt2-pro", + credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/zhinao/test_provider.py b/api/tests/integration_tests/model_runtime/zhinao/test_provider.py index 87b0e6c2d9..c22f797919 100644 --- a/api/tests/integration_tests/model_runtime/zhinao/test_provider.py +++ b/api/tests/integration_tests/model_runtime/zhinao/test_provider.py @@ -10,12 +10,6 @@ def test_validate_provider_credentials(): provider = ZhinaoProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('ZHINAO_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("ZHINAO_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py b/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py index 0f92b50cb0..20380513ea 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py @@ -18,41 +18,22 @@ def test_validate_credentials(): model = ZhipuAILargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='chatglm_turbo', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="chatglm_turbo", credentials={"api_key": "invalid_key"}) - model.validate_credentials( - model='chatglm_turbo', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - } - ) + model.validate_credentials(model="chatglm_turbo", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}) def test_invoke_model(): model = ZhipuAILargeLanguageModel() response = model.invoke( - model='chatglm_turbo', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Who are you?' - ) - ], - model_parameters={ - 'temperature': 0.9, - 'top_p': 0.7 - }, - stop=['How'], + model="chatglm_turbo", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Who are you?")], + model_parameters={"temperature": 0.9, "top_p": 0.7}, + stop=["How"], stream=False, - user="abc-123" + user="abc-123", ) assert isinstance(response, LLMResult) @@ -63,21 +44,12 @@ def test_invoke_stream_model(): model = ZhipuAILargeLanguageModel() response = model.invoke( - model='chatglm_turbo', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, - prompt_messages=[ - UserPromptMessage( - content='Hello World!' - ) - ], - model_parameters={ - 'temperature': 0.9, - 'top_p': 0.7 - }, + model="chatglm_turbo", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.9, "top_p": 0.7}, stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) @@ -93,63 +65,45 @@ def test_get_num_tokens(): model = ZhipuAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='chatglm_turbo', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, + model="chatglm_turbo", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 14 + def test_get_tools_num_tokens(): model = ZhipuAILargeLanguageModel() num_tokens = model.get_num_tokens( - model='tools', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, + model="tools", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, tools=[ PromptMessageTool( - name='get_current_weather', - description='Get the current weather in a given location', + name="get_current_weather", + description="Get the current weather in a given location", parameters={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": [ - "c", - "f" - ] - } + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, }, - "required": [ - "location" - ] - } + "required": ["location"], + }, ) ], prompt_messages=[ SystemPromptMessage( - content='You are a helpful AI assistant.', + content="You are a helpful AI assistant.", ), - UserPromptMessage( - content='Hello World!' - ) - ] + UserPromptMessage(content="Hello World!"), + ], ) assert num_tokens == 88 diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py b/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py index 51b9cccf2e..cb5bc0b20a 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py @@ -10,12 +10,6 @@ def test_validate_provider_credentials(): provider = ZhipuaiProvider() with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials( - credentials={} - ) + provider.validate_provider_credentials(credentials={}) - provider.validate_provider_credentials( - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - } - ) + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py index 7308c57296..9c97c91ecb 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py @@ -11,34 +11,19 @@ def test_validate_credentials(): model = ZhipuAITextEmbeddingModel() with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model='text_embedding', - credentials={ - 'api_key': 'invalid_key' - } - ) + model.validate_credentials(model="text_embedding", credentials={"api_key": "invalid_key"}) - model.validate_credentials( - model='text_embedding', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - } - ) + model.validate_credentials(model="text_embedding", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}) def test_invoke_model(): model = ZhipuAITextEmbeddingModel() result = model.invoke( - model='text_embedding', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, - texts=[ - "hello", - "world" - ], - user="abc-123" + model="text_embedding", + credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, + texts=["hello", "world"], + user="abc-123", ) assert isinstance(result, TextEmbeddingResult) @@ -50,14 +35,7 @@ def test_get_num_tokens(): model = ZhipuAITextEmbeddingModel() num_tokens = model.get_num_tokens( - model='text_embedding', - credentials={ - 'api_key': os.environ.get('ZHIPUAI_API_KEY') - }, - texts=[ - "hello", - "world" - ] + model="text_embedding", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, texts=["hello", "world"] ) assert num_tokens == 2 diff --git a/api/tests/integration_tests/tools/__mock/http.py b/api/tests/integration_tests/tools/__mock/http.py index 41bb3daeb5..4dfc530010 100644 --- a/api/tests/integration_tests/tools/__mock/http.py +++ b/api/tests/integration_tests/tools/__mock/http.py @@ -7,20 +7,17 @@ from _pytest.monkeypatch import MonkeyPatch class MockedHttp: - def httpx_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'], - url: str, **kwargs) -> httpx.Response: + def httpx_request( + method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs + ) -> httpx.Response: """ Mocked httpx.request """ request = httpx.Request( - method, - url, - params=kwargs.get('params'), - headers=kwargs.get('headers'), - cookies=kwargs.get('cookies') + method, url, params=kwargs.get("params"), headers=kwargs.get("headers"), cookies=kwargs.get("cookies") ) - data = kwargs.get('data', None) - resp = json.dumps(data).encode('utf-8') if data else b'OK' + data = kwargs.get("data", None) + resp = json.dumps(data).encode("utf-8") if data else b"OK" response = httpx.Response( status_code=200, request=request, diff --git a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py index ba14d365c5..83f4d70ce9 100644 --- a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py +++ b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py @@ -10,6 +10,7 @@ todos_data = { "user1": ["Go for a run", "Read a book"], } + class TodosResource(Resource): def get(self, username): todos = todos_data.get(username, []) @@ -32,7 +33,8 @@ class TodosResource(Resource): return {"error": "Invalid todo index"}, 400 -api.add_resource(TodosResource, '/todos/') -if __name__ == '__main__': +api.add_resource(TodosResource, "/todos/") + +if __name__ == "__main__": app.run(port=5003, debug=True) diff --git a/api/tests/integration_tests/tools/api_tool/test_api_tool.py b/api/tests/integration_tests/tools/api_tool/test_api_tool.py index f6e7b153dd..09729a961e 100644 --- a/api/tests/integration_tests/tools/api_tool/test_api_tool.py +++ b/api/tests/integration_tests/tools/api_tool/test_api_tool.py @@ -3,37 +3,40 @@ from core.tools.tool.tool import Tool from tests.integration_tests.tools.__mock.http import setup_http_mock tool_bundle = { - 'server_url': 'http://www.example.com/{path_param}', - 'method': 'post', - 'author': '', - 'openapi': {'parameters': [{'in': 'path', 'name': 'path_param'}, - {'in': 'query', 'name': 'query_param'}, - {'in': 'cookie', 'name': 'cookie_param'}, - {'in': 'header', 'name': 'header_param'}, - ], - 'requestBody': { - 'content': {'application/json': {'schema': {'properties': {'body_param': {'type': 'string'}}}}}} - }, - 'parameters': [] + "server_url": "http://www.example.com/{path_param}", + "method": "post", + "author": "", + "openapi": { + "parameters": [ + {"in": "path", "name": "path_param"}, + {"in": "query", "name": "query_param"}, + {"in": "cookie", "name": "cookie_param"}, + {"in": "header", "name": "header_param"}, + ], + "requestBody": { + "content": {"application/json": {"schema": {"properties": {"body_param": {"type": "string"}}}}} + }, + }, + "parameters": [], } parameters = { - 'path_param': 'p_param', - 'query_param': 'q_param', - 'cookie_param': 'c_param', - 'header_param': 'h_param', - 'body_param': 'b_param', + "path_param": "p_param", + "query_param": "q_param", + "cookie_param": "c_param", + "header_param": "h_param", + "body_param": "b_param", } def test_api_tool(setup_http_mock): - tool = ApiTool(api_bundle=tool_bundle, runtime=Tool.Runtime(credentials={'auth_type': 'none'})) + tool = ApiTool(api_bundle=tool_bundle, runtime=Tool.Runtime(credentials={"auth_type": "none"})) headers = tool.assembling_request(parameters) response = tool.do_http_request(tool.api_bundle.server_url, tool.api_bundle.method, headers, parameters) assert response.status_code == 200 - assert '/p_param' == response.request.url.path - assert b'query_param=q_param' == response.request.url.query - assert 'h_param' == response.request.headers.get('header_param') - assert 'application/json' == response.request.headers.get('content-type') - assert 'cookie_param=c_param' == response.request.headers.get('cookie') - assert 'b_param' in response.content.decode() + assert "/p_param" == response.request.url.path + assert b"query_param=q_param" == response.request.url.query + assert "h_param" == response.request.headers.get("header_param") + assert "application/json" == response.request.headers.get("content-type") + assert "cookie_param=c_param" == response.request.headers.get("cookie") + assert "b_param" in response.content.decode() diff --git a/api/tests/integration_tests/tools/test_all_provider.py b/api/tests/integration_tests/tools/test_all_provider.py index 2811bc816d..2dfce749b3 100644 --- a/api/tests/integration_tests/tools/test_all_provider.py +++ b/api/tests/integration_tests/tools/test_all_provider.py @@ -7,16 +7,17 @@ provider_names = [provider.identity.name for provider in provider_generator] ToolManager.clear_builtin_providers_cache() provider_generator = ToolManager.list_builtin_providers() -@pytest.mark.parametrize('name', provider_names) + +@pytest.mark.parametrize("name", provider_names) def test_tool_providers(benchmark, name): """ Test that all tool providers can be loaded """ - + def test(generator): try: return next(generator) except StopIteration: return None - - benchmark.pedantic(test, args=(provider_generator,), iterations=1, rounds=1) \ No newline at end of file + + benchmark.pedantic(test, args=(provider_generator,), iterations=1, rounds=1) diff --git a/api/tests/integration_tests/utils/parent_class.py b/api/tests/integration_tests/utils/parent_class.py index 39fc95256e..6a6de1cc41 100644 --- a/api/tests/integration_tests/utils/parent_class.py +++ b/api/tests/integration_tests/utils/parent_class.py @@ -3,4 +3,4 @@ class ParentClass: self.name = name def get_name(self): - return self.name \ No newline at end of file + return self.name diff --git a/api/tests/integration_tests/utils/test_module_import_helper.py b/api/tests/integration_tests/utils/test_module_import_helper.py index 256c9a911f..7d32f5ae66 100644 --- a/api/tests/integration_tests/utils/test_module_import_helper.py +++ b/api/tests/integration_tests/utils/test_module_import_helper.py @@ -7,26 +7,26 @@ from tests.integration_tests.utils.parent_class import ParentClass def test_loading_subclass_from_source(): current_path = os.getcwd() module = load_single_subclass_from_source( - module_name='ChildClass', - script_path=os.path.join(current_path, 'child_class.py'), - parent_type=ParentClass) - assert module and module.__name__ == 'ChildClass' + module_name="ChildClass", script_path=os.path.join(current_path, "child_class.py"), parent_type=ParentClass + ) + assert module and module.__name__ == "ChildClass" def test_load_import_module_from_source(): current_path = os.getcwd() module = import_module_from_source( - module_name='ChildClass', - py_file_path=os.path.join(current_path, 'child_class.py')) - assert module and module.__name__ == 'ChildClass' + module_name="ChildClass", py_file_path=os.path.join(current_path, "child_class.py") + ) + assert module and module.__name__ == "ChildClass" def test_lazy_loading_subclass_from_source(): current_path = os.getcwd() clz = load_single_subclass_from_source( - module_name='LazyLoadChildClass', - script_path=os.path.join(current_path, 'lazy_load_class.py'), + module_name="LazyLoadChildClass", + script_path=os.path.join(current_path, "lazy_load_class.py"), parent_type=ParentClass, - use_lazy_loader=True) - instance = clz('dify') - assert instance.get_name() == 'dify' + use_lazy_loader=True, + ) + instance = clz("dify") + assert instance.get_name() == "dify" diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index f8165cba94..571c1e3d44 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -13,11 +13,15 @@ from xinference_client.types import Embedding class MockTcvectordbClass: - - def VectorDBClient(self, url=None, username='', key='', - read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY, - timeout=5, - adapter: HTTPAdapter = None): + def VectorDBClient( + self, + url=None, + username="", + key="", + read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY, + timeout=5, + adapter: HTTPAdapter = None, + ): self._conn = None self._read_consistency = read_consistency @@ -26,105 +30,96 @@ class MockTcvectordbClass: Database( conn=self._conn, read_consistency=self._read_consistency, - name='dify', - )] + name="dify", + ) + ] def list_collections(self, timeout: Optional[float] = None) -> list[Collection]: return [] def drop_collection(self, name: str, timeout: Optional[float] = None): - return { - "code": 0, - "msg": "operation success" - } + return {"code": 0, "msg": "operation success"} def create_collection( - self, - name: str, - shard: int, - replicas: int, - description: str, - index: Index, - embedding: Embedding = None, - timeout: float = None, + self, + name: str, + shard: int, + replicas: int, + description: str, + index: Index, + embedding: Embedding = None, + timeout: float = None, ) -> Collection: - return Collection(self, name, shard, replicas, description, index, embedding=embedding, - read_consistency=self._read_consistency, timeout=timeout) - - def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection: - collection = Collection( + return Collection( self, name, - shard=1, - replicas=2, - description=name, - timeout=timeout + shard, + replicas, + description, + index, + embedding=embedding, + read_consistency=self._read_consistency, + timeout=timeout, ) + + def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection: + collection = Collection(self, name, shard=1, replicas=2, description=name, timeout=timeout) return collection def collection_upsert( - self, - documents: list[Document], - timeout: Optional[float] = None, - build_index: bool = True, - **kwargs + self, documents: list[Document], timeout: Optional[float] = None, build_index: bool = True, **kwargs ): - return { - "code": 0, - "msg": "operation success" - } + return {"code": 0, "msg": "operation success"} def collection_search( - self, - vectors: list[list[float]], - filter: Filter = None, - params=None, - retrieve_vector: bool = False, - limit: int = 10, - output_fields: Optional[list[str]] = None, - timeout: Optional[float] = None, + self, + vectors: list[list[float]], + filter: Filter = None, + params=None, + retrieve_vector: bool = False, + limit: int = 10, + output_fields: Optional[list[str]] = None, + timeout: Optional[float] = None, ) -> list[list[dict]]: - return [[{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}]] + return [[{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]] def collection_query( - self, - document_ids: Optional[list] = None, - retrieve_vector: bool = False, - limit: Optional[int] = None, - offset: Optional[int] = None, - filter: Optional[Filter] = None, - output_fields: Optional[list[str]] = None, - timeout: Optional[float] = None, + self, + document_ids: Optional[list] = None, + retrieve_vector: bool = False, + limit: Optional[int] = None, + offset: Optional[int] = None, + filter: Optional[Filter] = None, + output_fields: Optional[list[str]] = None, + timeout: Optional[float] = None, ) -> list[dict]: - return [{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}] + return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}] def collection_delete( - self, - document_ids: list[str] = None, - filter: Filter = None, - timeout: float = None, + self, + document_ids: list[str] = None, + filter: Filter = None, + timeout: float = None, ): - return { - "code": 0, - "msg": "operation success" - } + return {"code": 0, "msg": "operation success"} -MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + @pytest.fixture def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch): if MOCK: - monkeypatch.setattr(VectorDBClient, '__init__', MockTcvectordbClass.VectorDBClient) - monkeypatch.setattr(VectorDBClient, 'list_databases', MockTcvectordbClass.list_databases) - monkeypatch.setattr(Database, 'collection', MockTcvectordbClass.describe_collection) - monkeypatch.setattr(Database, 'list_collections', MockTcvectordbClass.list_collections) - monkeypatch.setattr(Database, 'drop_collection', MockTcvectordbClass.drop_collection) - monkeypatch.setattr(Database, 'create_collection', MockTcvectordbClass.create_collection) - monkeypatch.setattr(Collection, 'upsert', MockTcvectordbClass.collection_upsert) - monkeypatch.setattr(Collection, 'search', MockTcvectordbClass.collection_search) - monkeypatch.setattr(Collection, 'query', MockTcvectordbClass.collection_query) - monkeypatch.setattr(Collection, 'delete', MockTcvectordbClass.collection_delete) + monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.VectorDBClient) + monkeypatch.setattr(VectorDBClient, "list_databases", MockTcvectordbClass.list_databases) + monkeypatch.setattr(Database, "collection", MockTcvectordbClass.describe_collection) + monkeypatch.setattr(Database, "list_collections", MockTcvectordbClass.list_collections) + monkeypatch.setattr(Database, "drop_collection", MockTcvectordbClass.drop_collection) + monkeypatch.setattr(Database, "create_collection", MockTcvectordbClass.create_collection) + monkeypatch.setattr(Collection, "upsert", MockTcvectordbClass.collection_upsert) + monkeypatch.setattr(Collection, "search", MockTcvectordbClass.collection_search) + monkeypatch.setattr(Collection, "query", MockTcvectordbClass.collection_query) + monkeypatch.setattr(Collection, "delete", MockTcvectordbClass.collection_delete) yield diff --git a/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py b/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py index d6067af73b..970b98edc3 100644 --- a/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py +++ b/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py @@ -26,6 +26,7 @@ class AnalyticdbVectorTest(AbstractVectorTest): def run_all_tests(self): self.vector.delete() return super().run_all_tests() - + + def test_chroma_vector(setup_mock_redis): - AnalyticdbVectorTest().run_all_tests() \ No newline at end of file + AnalyticdbVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/chroma/test_chroma.py b/api/tests/integration_tests/vdb/chroma/test_chroma.py index 033f9a54da..ac7b5cbda4 100644 --- a/api/tests/integration_tests/vdb/chroma/test_chroma.py +++ b/api/tests/integration_tests/vdb/chroma/test_chroma.py @@ -14,13 +14,13 @@ class ChromaVectorTest(AbstractVectorTest): self.vector = ChromaVector( collection_name=self.collection_name, config=ChromaConfig( - host='localhost', + host="localhost", port=8000, tenant=chromadb.DEFAULT_TENANT, database=chromadb.DEFAULT_DATABASE, auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider", auth_credentials="difyai123456", - ) + ), ) def search_by_full_text(self): diff --git a/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py b/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py index b1c1cc10d9..2a0c1bb038 100644 --- a/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py +++ b/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py @@ -8,16 +8,11 @@ from tests.integration_tests.vdb.test_vector_store import ( class ElasticSearchVectorTest(AbstractVectorTest): def __init__(self): super().__init__() - self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] + self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] self.vector = ElasticSearchVector( index_name=self.collection_name.lower(), - config=ElasticSearchConfig( - host='http://localhost', - port='9200', - username='elastic', - password='elastic' - ), - attributes=self.attributes + config=ElasticSearchConfig(host="http://localhost", port="9200", username="elastic", password="elastic"), + attributes=self.attributes, ) diff --git a/api/tests/integration_tests/vdb/milvus/test_milvus.py b/api/tests/integration_tests/vdb/milvus/test_milvus.py index 9c0917ef30..7b5f19ea62 100644 --- a/api/tests/integration_tests/vdb/milvus/test_milvus.py +++ b/api/tests/integration_tests/vdb/milvus/test_milvus.py @@ -12,11 +12,11 @@ class MilvusVectorTest(AbstractVectorTest): self.vector = MilvusVector( collection_name=self.collection_name, config=MilvusConfig( - host='localhost', + host="localhost", port=19530, - user='root', - password='Milvus', - ) + user="root", + password="Milvus", + ), ) def search_by_full_text(self): @@ -25,7 +25,7 @@ class MilvusVectorTest(AbstractVectorTest): assert len(hits_by_full_text) == 0 def get_ids_by_metadata_field(self): - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 diff --git a/api/tests/integration_tests/vdb/myscale/test_myscale.py b/api/tests/integration_tests/vdb/myscale/test_myscale.py index b6260d549a..55b2fde427 100644 --- a/api/tests/integration_tests/vdb/myscale/test_myscale.py +++ b/api/tests/integration_tests/vdb/myscale/test_myscale.py @@ -21,7 +21,7 @@ class MyScaleVectorTest(AbstractVectorTest): ) def get_ids_by_metadata_field(self): - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 diff --git a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py index ea1e05da90..a99b81d41e 100644 --- a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py +++ b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py @@ -29,54 +29,55 @@ class TestOpenSearchVector: self.example_doc_id = "example_doc_id" self.vector = OpenSearchVector( collection_name=self.collection_name, - config=OpenSearchConfig( - host='localhost', - port=9200, - user='admin', - password='password', - secure=False - ) + config=OpenSearchConfig(host="localhost", port=9200, user="admin", password="password", secure=False), ) self.vector._client = MagicMock() - @pytest.mark.parametrize("search_response, expected_length, expected_doc_id", [ - ({ - 'hits': { - 'total': {'value': 1}, - 'hits': [ - {'_source': {'page_content': get_example_text(), 'metadata': {"document_id": "example_doc_id"}}} - ] - } - }, 1, "example_doc_id"), - ({ - 'hits': { - 'total': {'value': 0}, - 'hits': [] - } - }, 0, None) - ]) + @pytest.mark.parametrize( + "search_response, expected_length, expected_doc_id", + [ + ( + { + "hits": { + "total": {"value": 1}, + "hits": [ + { + "_source": { + "page_content": get_example_text(), + "metadata": {"document_id": "example_doc_id"}, + } + } + ], + } + }, + 1, + "example_doc_id", + ), + ({"hits": {"total": {"value": 0}, "hits": []}}, 0, None), + ], + ) def test_search_by_full_text(self, search_response, expected_length, expected_doc_id): self.vector._client.search.return_value = search_response hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) assert len(hits_by_full_text) == expected_length if expected_length > 0: - assert hits_by_full_text[0].metadata['document_id'] == expected_doc_id + assert hits_by_full_text[0].metadata["document_id"] == expected_doc_id def test_search_by_vector(self): vector = [0.1] * 128 mock_response = { - 'hits': { - 'total': {'value': 1}, - 'hits': [ + "hits": { + "total": {"value": 1}, + "hits": [ { - '_source': { + "_source": { Field.CONTENT_KEY.value: get_example_text(), - Field.METADATA_KEY.value: {"document_id": self.example_doc_id} + Field.METADATA_KEY.value: {"document_id": self.example_doc_id}, }, - '_score': 1.0 + "_score": 1.0, } - ] + ], } } self.vector._client.search.return_value = mock_response @@ -85,53 +86,45 @@ class TestOpenSearchVector: print("Hits by vector:", hits_by_vector) print("Expected document ID:", self.example_doc_id) - print("Actual document ID:", hits_by_vector[0].metadata['document_id'] if hits_by_vector else "No hits") + print("Actual document ID:", hits_by_vector[0].metadata["document_id"] if hits_by_vector else "No hits") assert len(hits_by_vector) > 0, f"Expected at least one hit, got {len(hits_by_vector)}" - assert hits_by_vector[0].metadata['document_id'] == self.example_doc_id, \ - f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}" + assert ( + hits_by_vector[0].metadata["document_id"] == self.example_doc_id + ), f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}" def test_get_ids_by_metadata_field(self): - mock_response = { - 'hits': { - 'total': {'value': 1}, - 'hits': [{'_id': 'mock_id'}] - } - } + mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}} self.vector._client.search.return_value = mock_response doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id}) embedding = [0.1] * 128 - with patch('opensearchpy.helpers.bulk') as mock_bulk: + with patch("opensearchpy.helpers.bulk") as mock_bulk: mock_bulk.return_value = ([], []) self.vector.add_texts([doc], [embedding]) - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 - assert ids[0] == 'mock_id' + assert ids[0] == "mock_id" def test_add_texts(self): - self.vector._client.index.return_value = {'result': 'created'} + self.vector._client.index.return_value = {"result": "created"} doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id}) embedding = [0.1] * 128 - with patch('opensearchpy.helpers.bulk') as mock_bulk: + with patch("opensearchpy.helpers.bulk") as mock_bulk: mock_bulk.return_value = ([], []) self.vector.add_texts([doc], [embedding]) - mock_response = { - 'hits': { - 'total': {'value': 1}, - 'hits': [{'_id': 'mock_id'}] - } - } + mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}} self.vector._client.search.return_value = mock_response - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 - assert ids[0] == 'mock_id' + assert ids[0] == "mock_id" + @pytest.mark.usefixtures("setup_mock_redis") class TestOpenSearchVectorWithRedis: @@ -141,11 +134,11 @@ class TestOpenSearchVectorWithRedis: def test_search_by_full_text(self): self.tester.setup_method() search_response = { - 'hits': { - 'total': {'value': 1}, - 'hits': [ - {'_source': {'page_content': get_example_text(), 'metadata': {"document_id": "example_doc_id"}}} - ] + "hits": { + "total": {"value": 1}, + "hits": [ + {"_source": {"page_content": get_example_text(), "metadata": {"document_id": "example_doc_id"}}} + ], } } expected_length = 1 diff --git a/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py b/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py index e6ce8aab3d..6b33217d15 100644 --- a/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py +++ b/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py @@ -12,13 +12,13 @@ class PGVectoRSVectorTest(AbstractVectorTest): self.vector = PGVectoRS( collection_name=self.collection_name.lower(), config=PgvectoRSConfig( - host='localhost', + host="localhost", port=5431, - user='postgres', - password='difyai123456', - database='dify', + user="postgres", + password="difyai123456", + database="dify", ), - dim=128 + dim=128, ) def search_by_full_text(self): @@ -27,8 +27,9 @@ class PGVectoRSVectorTest(AbstractVectorTest): assert len(hits_by_full_text) == 0 def get_ids_by_metadata_field(self): - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 1 + def test_pgvecot_rs(setup_mock_redis): PGVectoRSVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py index 34beb25d45..61d9a9e712 100644 --- a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py +++ b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py @@ -8,14 +8,14 @@ from tests.integration_tests.vdb.test_vector_store import ( class QdrantVectorTest(AbstractVectorTest): def __init__(self): super().__init__() - self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] + self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] self.vector = QdrantVector( collection_name=self.collection_name, group_id=self.dataset_id, config=QdrantConfig( - endpoint='http://localhost:6333', - api_key='difyai123456', - ) + endpoint="http://localhost:6333", + api_key="difyai123456", + ), ) diff --git a/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py b/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py index 8937fe0ea1..1b9466e27f 100644 --- a/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py +++ b/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py @@ -7,18 +7,22 @@ from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, ge mock_client = MagicMock() mock_client.list_databases.return_value = [{"name": "test"}] + class TencentVectorTest(AbstractVectorTest): def __init__(self): super().__init__() - self.vector = TencentVector("dify", TencentConfig( - url="http://127.0.0.1", - api_key="dify", - timeout=30, - username="dify", - database="dify", - shard=1, - replicas=2, - )) + self.vector = TencentVector( + "dify", + TencentConfig( + url="http://127.0.0.1", + api_key="dify", + timeout=30, + username="dify", + database="dify", + shard=1, + replicas=2, + ), + ) def search_by_vector(self): hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding) @@ -28,8 +32,6 @@ class TencentVectorTest(AbstractVectorTest): hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) assert len(hits_by_full_text) == 0 -def test_tencent_vector(setup_mock_redis,setup_tcvectordb_mock): + +def test_tencent_vector(setup_mock_redis, setup_tcvectordb_mock): TencentVectorTest().run_all_tests() - - - diff --git a/api/tests/integration_tests/vdb/test_vector_store.py b/api/tests/integration_tests/vdb/test_vector_store.py index cb35822709..a11cd225b3 100644 --- a/api/tests/integration_tests/vdb/test_vector_store.py +++ b/api/tests/integration_tests/vdb/test_vector_store.py @@ -10,7 +10,7 @@ from models.dataset import Dataset def get_example_text() -> str: - return 'test_text' + return "test_text" def get_example_document(doc_id: str) -> Document: @@ -21,7 +21,7 @@ def get_example_document(doc_id: str) -> Document: "doc_hash": doc_id, "document_id": doc_id, "dataset_id": doc_id, - } + }, ) return doc @@ -45,7 +45,7 @@ class AbstractVectorTest: def __init__(self): self.vector = None self.dataset_id = str(uuid.uuid4()) - self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + '_test' + self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test" self.example_doc_id = str(uuid.uuid4()) self.example_embedding = [1.001 * i for i in range(128)] @@ -58,12 +58,12 @@ class AbstractVectorTest: def search_by_vector(self): hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding) assert len(hits_by_vector) == 1 - assert hits_by_vector[0].metadata['doc_id'] == self.example_doc_id + assert hits_by_vector[0].metadata["doc_id"] == self.example_doc_id def search_by_full_text(self): hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text()) assert len(hits_by_full_text) == 1 - assert hits_by_full_text[0].metadata['doc_id'] == self.example_doc_id + assert hits_by_full_text[0].metadata["doc_id"] == self.example_doc_id def delete_vector(self): self.vector.delete() @@ -76,14 +76,14 @@ class AbstractVectorTest: documents = [get_example_document(doc_id=str(uuid.uuid4())) for _ in range(batch_size)] embeddings = [self.example_embedding] * batch_size self.vector.add_texts(documents=documents, embeddings=embeddings) - return [doc.metadata['doc_id'] for doc in documents] + return [doc.metadata["doc_id"] for doc in documents] def text_exists(self): assert self.vector.text_exists(self.example_doc_id) def get_ids_by_metadata_field(self): with pytest.raises(NotImplementedError): - self.vector.get_ids_by_metadata_field(key='key', value='value') + self.vector.get_ids_by_metadata_field(key="key", value="value") def run_all_tests(self): self.create_vector() diff --git a/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py b/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py index 18e00dbedd..2a5320c7d5 100644 --- a/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py +++ b/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py @@ -10,15 +10,15 @@ from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, ge @pytest.fixture def tidb_vector(): return TiDBVector( - collection_name='test_collection', + collection_name="test_collection", config=TiDBVectorConfig( host="xxx.eu-central-1.xxx.aws.tidbcloud.com", port="4000", user="xxx.root", password="xxxxxx", database="dify", - program_name="langgenius/dify" - ) + program_name="langgenius/dify", + ), ) @@ -40,7 +40,7 @@ class TiDBVectorTest(AbstractVectorTest): assert len(hits_by_full_text) == 0 def get_ids_by_metadata_field(self): - ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) assert len(ids) == 0 @@ -50,12 +50,12 @@ def test_tidb_vector(setup_mock_redis, setup_tidbvector_mock, tidb_vector, mock_ @pytest.fixture def mock_session(): - with patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.Session', new_callable=MagicMock) as mock_session: + with patch("core.rag.datasource.vdb.tidb_vector.tidb_vector.Session", new_callable=MagicMock) as mock_session: yield mock_session @pytest.fixture def setup_tidbvector_mock(tidb_vector, mock_session): - with patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.create_engine'): - with patch.object(tidb_vector._engine, 'connect'): + with patch("core.rag.datasource.vdb.tidb_vector.tidb_vector.create_engine"): + with patch.object(tidb_vector._engine, "connect"): yield tidb_vector diff --git a/api/tests/integration_tests/vdb/weaviate/test_weaviate.py b/api/tests/integration_tests/vdb/weaviate/test_weaviate.py index 3d540cee32..a6f55420d3 100644 --- a/api/tests/integration_tests/vdb/weaviate/test_weaviate.py +++ b/api/tests/integration_tests/vdb/weaviate/test_weaviate.py @@ -8,14 +8,14 @@ from tests.integration_tests.vdb.test_vector_store import ( class WeaviateVectorTest(AbstractVectorTest): def __init__(self): super().__init__() - self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] + self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] self.vector = WeaviateVector( collection_name=self.collection_name, config=WeaviateConfig( - endpoint='http://localhost:8080', - api_key='WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih', + endpoint="http://localhost:8080", + api_key="WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih", ), - attributes=self.attributes + attributes=self.attributes, ) diff --git a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py index 51398ccb32..6fb8c86b82 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py @@ -7,25 +7,22 @@ from jinja2 import Template from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage -MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false") == "true" + class MockedCodeExecutor: @classmethod - def invoke(cls, language: Literal['python3', 'javascript', 'jinja2'], - code: str, inputs: dict) -> dict: + def invoke(cls, language: Literal["python3", "javascript", "jinja2"], code: str, inputs: dict) -> dict: # invoke directly match language: case CodeLanguage.PYTHON3: - return { - "result": 3 - } + return {"result": 3} case CodeLanguage.JINJA2: - return { - "result": Template(code).render(inputs) - } + return {"result": Template(code).render(inputs)} case _: raise Exception("Language not supported") + @pytest.fixture def setup_code_executor_mock(request, monkeypatch: MonkeyPatch): if not MOCK: diff --git a/api/tests/integration_tests/workflow/nodes/__mock/http.py b/api/tests/integration_tests/workflow/nodes/__mock/http.py index beb5c04009..cfc47bcad4 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/http.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/http.py @@ -6,38 +6,32 @@ import httpx import pytest from _pytest.monkeypatch import MonkeyPatch -MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' +MOCK = os.getenv("MOCK_SWITCH", "false") == "true" class MockedHttp: - def httpx_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'], - url: str, **kwargs) -> httpx.Response: + def httpx_request( + method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs + ) -> httpx.Response: """ Mocked httpx.request """ - if url == 'http://404.com': - response = httpx.Response( - status_code=404, - request=httpx.Request(method, url), - content=b'Not Found' - ) + if url == "http://404.com": + response = httpx.Response(status_code=404, request=httpx.Request(method, url), content=b"Not Found") return response # get data, files - data = kwargs.get('data', None) - files = kwargs.get('files', None) + data = kwargs.get("data", None) + files = kwargs.get("files", None) if data is not None: - resp = dumps(data).encode('utf-8') + resp = dumps(data).encode("utf-8") elif files is not None: - resp = dumps(files).encode('utf-8') + resp = dumps(files).encode("utf-8") else: - resp = b'OK' + resp = b"OK" response = httpx.Response( - status_code=200, - request=httpx.Request(method, url), - headers=kwargs.get('headers', {}), - content=resp + status_code=200, request=httpx.Request(method, url), headers=kwargs.get("headers", {}), content=resp ) return response diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py index ae6e7ceaa7..44dcf9a10f 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py @@ -2,10 +2,10 @@ import pytest from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor -CODE_LANGUAGE = 'unsupported_language' +CODE_LANGUAGE = "unsupported_language" def test_unsupported_with_code_template(): with pytest.raises(CodeExecutionException) as e: - CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code='', inputs={}) - assert str(e.value) == f'Unsupported language {CODE_LANGUAGE}' + CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={}) + assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}" diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py index 0757caba7b..09fcb68cf0 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_javascript.py @@ -9,8 +9,8 @@ CODE_LANGUAGE = CodeLanguage.JAVASCRIPT def test_javascript_plain(): code = 'console.log("Hello World")' - result_message = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload='', code=code) - assert result_message == 'Hello World\n' + result_message = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) + assert result_message == "Hello World\n" def test_javascript_json(): @@ -18,15 +18,18 @@ def test_javascript_json(): obj = {'Hello': 'World'} console.log(JSON.stringify(obj)) """) - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload='', code=code) + result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) assert result == '{"Hello":"World"}\n' def test_javascript_with_code_template(): result = CodeExecutor.execute_workflow_code_template( - language=CODE_LANGUAGE, code=JavascriptCodeProvider.get_default_code(), - inputs={'arg1': 'Hello', 'arg2': 'World'}) - assert result == {'result': 'HelloWorld'} + language=CODE_LANGUAGE, + code=JavascriptCodeProvider.get_default_code(), + inputs={"arg1": "Hello", "arg2": "World"}, + ) + assert result == {"result": "HelloWorld"} + def test_javascript_get_runner_script(): runner_script = NodeJsTemplateTransformer.get_runner_script() diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py index 425f4cbdd4..94903cf796 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py @@ -7,21 +7,24 @@ CODE_LANGUAGE = CodeLanguage.JINJA2 def test_jinja2(): - template = 'Hello {{template}}' - inputs = base64.b64encode(b'{"template": "World"}').decode('utf-8') - code = (Jinja2TemplateTransformer.get_runner_script() - .replace(Jinja2TemplateTransformer._code_placeholder, template) - .replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)) - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, - preload=Jinja2TemplateTransformer.get_preload_script(), - code=code) - assert result == '<>Hello World<>\n' + template = "Hello {{template}}" + inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8") + code = ( + Jinja2TemplateTransformer.get_runner_script() + .replace(Jinja2TemplateTransformer._code_placeholder, template) + .replace(Jinja2TemplateTransformer._inputs_placeholder, inputs) + ) + result = CodeExecutor.execute_code( + language=CODE_LANGUAGE, preload=Jinja2TemplateTransformer.get_preload_script(), code=code + ) + assert result == "<>Hello World<>\n" def test_jinja2_with_code_template(): result = CodeExecutor.execute_workflow_code_template( - language=CODE_LANGUAGE, code='Hello {{template}}', inputs={'template': 'World'}) - assert result == {'result': 'Hello World'} + language=CODE_LANGUAGE, code="Hello {{template}}", inputs={"template": "World"} + ) + assert result == {"result": "Hello World"} def test_jinja2_get_runner_script(): diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py index 9d7e86cd68..cbe4a5d335 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py @@ -10,8 +10,8 @@ CODE_LANGUAGE = CodeLanguage.PYTHON3 def test_python3_plain(): code = 'print("Hello World")' - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload='', code=code) - assert result == 'Hello World\n' + result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) + assert result == "Hello World\n" def test_python3_json(): @@ -19,14 +19,15 @@ def test_python3_json(): import json print(json.dumps({'Hello': 'World'})) """) - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload='', code=code) + result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) assert result == '{"Hello": "World"}\n' def test_python3_with_code_template(): result = CodeExecutor.execute_workflow_code_template( - language=CODE_LANGUAGE, code=Python3CodeProvider.get_default_code(), inputs={'arg1': 'Hello', 'arg2': 'World'}) - assert result == {'result': 'HelloWorld'} + language=CODE_LANGUAGE, code=Python3CodeProvider.get_default_code(), inputs={"arg1": "Hello", "arg2": "World"} + ) + assert result == {"result": "HelloWorld"} def test_python3_get_runner_script(): diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 5c95258520..6f5421e108 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -9,137 +9,134 @@ from core.workflow.nodes.code.code_node import CodeNode from models.workflow import WorkflowNodeExecutionStatus from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock -CODE_MAX_STRING_LENGTH = int(getenv('CODE_MAX_STRING_LENGTH', '10000')) +CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000")) -@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) def test_execute_code(setup_code_executor_mock): - code = ''' + code = """ def main(args1: int, args2: int) -> dict: return { "result": args1 + args2, } - ''' + """ # trim first 4 spaces at the beginning of each line - code = '\n'.join([line[4:] for line in code.split('\n')]) + code = "\n".join([line[4:] for line in code.split("\n")]) node = CodeNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, config={ - 'id': '1', - 'data': { - 'outputs': { - 'result': { - 'type': 'number', + "id": "1", + "data": { + "outputs": { + "result": { + "type": "number", }, }, - 'title': '123', - 'variables': [ + "title": "123", + "variables": [ { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], + "variable": "args1", + "value_selector": ["1", "123", "args1"], }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, ], - 'answer': '123', - 'code_language': 'python3', - 'code': code - } - } + "answer": "123", + "code_language": "python3", + "code": code, + }, + }, ) # construct variable pool pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(['1', '123', 'args1'], 1) - pool.add(['1', '123', 'args2'], 2) - + pool.add(["1", "123", "args1"], 1) + pool.add(["1", "123", "args2"], 2) + # execute node result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['result'] == 3 + assert result.outputs["result"] == 3 assert result.error is None -@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) def test_execute_code_output_validator(setup_code_executor_mock): - code = ''' + code = """ def main(args1: int, args2: int) -> dict: return { "result": args1 + args2, } - ''' + """ # trim first 4 spaces at the beginning of each line - code = '\n'.join([line[4:] for line in code.split('\n')]) + code = "\n".join([line[4:] for line in code.split("\n")]) node = CodeNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, config={ - 'id': '1', - 'data': { + "id": "1", + "data": { "outputs": { "result": { "type": "string", }, }, - 'title': '123', - 'variables': [ + "title": "123", + "variables": [ { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], + "variable": "args1", + "value_selector": ["1", "123", "args1"], }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, ], - 'answer': '123', - 'code_language': 'python3', - 'code': code - } - } + "answer": "123", + "code_language": "python3", + "code": code, + }, + }, ) # construct variable pool pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(['1', '123', 'args1'], 1) - pool.add(['1', '123', 'args2'], 2) - + pool.add(["1", "123", "args1"], 1) + pool.add(["1", "123", "args2"], 2) + # execute node result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.FAILED - assert result.error == 'Output variable `result` must be a string' + assert result.error == "Output variable `result` must be a string" + def test_execute_code_output_validator_depth(): - code = ''' + code = """ def main(args1: int, args2: int) -> dict: return { "result": { "result": args1 + args2, } } - ''' + """ # trim first 4 spaces at the beginning of each line - code = '\n'.join([line[4:] for line in code.split('\n')]) + code = "\n".join([line[4:] for line in code.split("\n")]) node = CodeNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, config={ - 'id': '1', - 'data': { + "id": "1", + "data": { "outputs": { "string_validator": { "type": "string", @@ -168,29 +165,26 @@ def test_execute_code_output_validator_depth(): "depth": { "type": "number", } - } + }, } - } - } - } + }, + }, + }, }, }, - 'title': '123', - 'variables': [ + "title": "123", + "variables": [ { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], + "variable": "args1", + "value_selector": ["1", "123", "args1"], }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, ], - 'answer': '123', - 'code_language': 'python3', - 'code': code - } - } + "answer": "123", + "code_language": "python3", + "code": code, + }, + }, ) # construct result @@ -199,14 +193,7 @@ def test_execute_code_output_validator_depth(): "string_validator": "1", "number_array_validator": [1, 2, 3, 3.333], "string_array_validator": ["1", "2", "3"], - "object_validator": { - "result": 1, - "depth": { - "depth": { - "depth": 1 - } - } - } + "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, } # validate @@ -218,14 +205,7 @@ def test_execute_code_output_validator_depth(): "string_validator": 1, "number_array_validator": ["1", "2", "3", "3.333"], "string_array_validator": [1, 2, 3], - "object_validator": { - "result": "1", - "depth": { - "depth": { - "depth": "1" - } - } - } + "object_validator": {"result": "1", "depth": {"depth": {"depth": "1"}}}, } # validate @@ -238,34 +218,20 @@ def test_execute_code_output_validator_depth(): "string_validator": (CODE_MAX_STRING_LENGTH + 1) * "1", "number_array_validator": [1, 2, 3, 3.333], "string_array_validator": ["1", "2", "3"], - "object_validator": { - "result": 1, - "depth": { - "depth": { - "depth": 1 - } - } - } + "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, } # validate with pytest.raises(ValueError): node._transform_result(result, node.node_data.outputs) - + # construct result result = { "number_validator": 1, "string_validator": "1", "number_array_validator": [1, 2, 3, 3.333] * 2000, "string_array_validator": ["1", "2", "3"], - "object_validator": { - "result": 1, - "depth": { - "depth": { - "depth": 1 - } - } - } + "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, } # validate @@ -274,58 +240,59 @@ def test_execute_code_output_validator_depth(): def test_execute_code_output_object_list(): - code = ''' + code = """ def main(args1: int, args2: int) -> dict: return { "result": { "result": args1 + args2, } } - ''' + """ # trim first 4 spaces at the beginning of each line - code = '\n'.join([line[4:] for line in code.split('\n')]) + code = "\n".join([line[4:] for line in code.split("\n")]) node = CodeNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': '1', - 'data': { + "id": "1", + "data": { "outputs": { "object_list": { "type": "array[object]", }, }, - 'title': '123', - 'variables': [ + "title": "123", + "variables": [ { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], + "variable": "args1", + "value_selector": ["1", "123", "args1"], }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, ], - 'answer': '123', - 'code_language': 'python3', - 'code': code - } - } + "answer": "123", + "code_language": "python3", + "code": code, + }, + }, ) # construct result result = { - "object_list": [{ - "result": 1, - }, { - "result": 2, - }, { - "result": [1, 2, 3], - }] + "object_list": [ + { + "result": 1, + }, + { + "result": 2, + }, + { + "result": [1, 2, 3], + }, + ] } # validate @@ -333,13 +300,18 @@ def test_execute_code_output_object_list(): # construct result result = { - "object_list": [{ - "result": 1, - }, { - "result": 2, - }, { - "result": [1, 2, 3], - }, 1] + "object_list": [ + { + "result": 1, + }, + { + "result": 2, + }, + { + "result": [1, 2, 3], + }, + 1, + ] } # validate diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index a1354bd6a5..acb616b325 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -9,322 +9,337 @@ from core.workflow.nodes.http_request.http_request_node import HttpRequestNode from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock BASIC_NODE_DATA = { - 'tenant_id': '1', - 'app_id': '1', - 'workflow_id': '1', - 'user_id': '1', - 'user_from': UserFrom.ACCOUNT, - 'invoke_from': InvokeFrom.WEB_APP, + "tenant_id": "1", + "app_id": "1", + "workflow_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.WEB_APP, } # construct variable pool pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) -pool.add(['a', 'b123', 'args1'], 1) -pool.add(['a', 'b123', 'args2'], 2) +pool.add(["a", "b123", "args1"], 1) +pool.add(["a", "b123", "args2"], 2) -@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_get(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } - }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': None, - } - }, **BASIC_NODE_DATA) - - result = node.run(pool) - - data = result.process_data.get('request', '') - - assert '?A=b' in data - assert 'X-Header: 123' in data - - -@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) -def test_no_auth(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://example.com', - 'authorization': { - 'type': 'no-auth', - 'config': None, - }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': None, - } - }, **BASIC_NODE_DATA) - - result = node.run(pool) - - data = result.process_data.get('request', '') - - assert '?A=b' in data - assert 'X-Header: 123' in data - - -@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) -def test_custom_authorization_header(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'custom', - 'api_key': 'Auth', - 'header': 'X-Auth', + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, }, + "headers": "X-Header:123", + "params": "A:b", + "body": None, }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': None, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") - assert '?A=b' in data - assert 'X-Header: 123' in data + assert "?A=b" in data + assert "X-Header: 123" in data -@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) +def test_no_auth(setup_http_mock): + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "no-auth", + "config": None, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": None, + }, + }, + **BASIC_NODE_DATA, + ) + + result = node.run(pool) + + data = result.process_data.get("request", "") + + assert "?A=b" in data + assert "X-Header: 123" in data + + +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) +def test_custom_authorization_header(setup_http_mock): + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "custom", + "api_key": "Auth", + "header": "X-Auth", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": None, + }, + }, + **BASIC_NODE_DATA, + ) + + result = node.run(pool) + + data = result.process_data.get("request", "") + + assert "?A=b" in data + assert "X-Header: 123" in data + + +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_template(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://example.com/{{#a.b123.args2#}}', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com/{{#a.b123.args2#}}", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123\nX-Header2:{{#a.b123.args2#}}", + "params": "A:b\nTemplate:{{#a.b123.args2#}}", + "body": None, }, - 'headers': 'X-Header:123\nX-Header2:{{#a.b123.args2#}}', - 'params': 'A:b\nTemplate:{{#a.b123.args2#}}', - 'body': None, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") - assert '?A=b' in data - assert 'Template=2' in data - assert 'X-Header: 123' in data - assert 'X-Header2: 2' in data + assert "?A=b" in data + assert "Template=2" in data + assert "X-Header: 123" in data + assert "X-Header2: 2" in data -@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_json(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'post', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "post", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": {"type": "json", "data": '{"a": "{{#a.b123.args1#}}"}'}, }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': { - 'type': 'json', - 'data': '{"a": "{{#a.b123.args1#}}"}' - }, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") assert '{"a": "1"}' in data - assert 'X-Header: 123' in data + assert "X-Header: 123" in data def test_x_www_form_urlencoded(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'post', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "post", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": {"type": "x-www-form-urlencoded", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"}, }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': { - 'type': 'x-www-form-urlencoded', - 'data': 'a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}' - }, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") - assert 'a=1&b=2' in data - assert 'X-Header: 123' in data + assert "a=1&b=2" in data + assert "X-Header: 123" in data def test_form_data(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'post', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "post", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": {"type": "form-data", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"}, }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': { - 'type': 'form-data', - 'data': 'a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}' - }, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") assert 'form-data; name="a"' in data - assert '1' in data + assert "1" in data assert 'form-data; name="b"' in data - assert '2' in data - assert 'X-Header: 123' in data + assert "2" in data + assert "X-Header: 123" in data def test_none_data(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'post', - 'url': 'http://example.com', - 'authorization': { - 'type': 'api-key', - 'config': { - 'type': 'basic', - 'api_key': 'ak-xxx', - 'header': 'api-key', - } + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "post", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:123", + "params": "A:b", + "body": {"type": "none", "data": "123123123"}, }, - 'headers': 'X-Header:123', - 'params': 'A:b', - 'body': { - 'type': 'none', - 'data': '123123123' - }, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) - data = result.process_data.get('request', '') + data = result.process_data.get("request", "") - assert 'X-Header: 123' in data - assert '123123123' not in data + assert "X-Header: 123" in data + assert "123123123" not in data def test_mock_404(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://404.com', - 'authorization': { - 'type': 'no-auth', - 'config': None, + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://404.com", + "authorization": { + "type": "no-auth", + "config": None, + }, + "body": None, + "params": "", + "headers": "X-Header:123", }, - 'body': None, - 'params': '', - 'headers': 'X-Header:123', - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) resp = result.outputs - assert 404 == resp.get('status_code') - assert 'Not Found' in resp.get('body') + assert 404 == resp.get("status_code") + assert "Not Found" in resp.get("body") def test_multi_colons_parse(setup_http_mock): - node = HttpRequestNode(config={ - 'id': '1', - 'data': { - 'title': 'http', - 'desc': '', - 'method': 'get', - 'url': 'http://example.com', - 'authorization': { - 'type': 'no-auth', - 'config': None, + node = HttpRequestNode( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "no-auth", + "config": None, + }, + "params": "Referer:http://example1.com\nRedirect:http://example2.com", + "headers": "Referer:http://example3.com\nRedirect:http://example4.com", + "body": {"type": "form-data", "data": "Referer:http://example5.com\nRedirect:http://example6.com"}, }, - 'params': 'Referer:http://example1.com\nRedirect:http://example2.com', - 'headers': 'Referer:http://example3.com\nRedirect:http://example4.com', - 'body': { - 'type': 'form-data', - 'data': 'Referer:http://example5.com\nRedirect:http://example6.com' - }, - } - }, **BASIC_NODE_DATA) + }, + **BASIC_NODE_DATA, + ) result = node.run(pool) resp = result.outputs - assert urlencode({'Redirect': 'http://example2.com'}) in result.process_data.get('request') - assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get('request') - assert 'http://example3.com' == resp.get('headers').get('referer') + assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request") + assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get("request") + assert "http://example3.com" == resp.get("headers").get("referer") diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 1b27af5af7..6bab83a019 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -23,90 +23,71 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_moc from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_execute_llm(setup_openai_mock): node = LLMNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'llm', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo', - 'mode': 'chat', - 'completion_params': {} - }, - 'prompt_template': [ - { - 'role': 'system', - 'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}.' - }, - { - 'role': 'user', - 'text': '{{#sys.query#}}' - } + "id": "llm", + "data": { + "title": "123", + "type": "llm", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "prompt_template": [ + {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."}, + {"role": "user", "text": "{{#sys.query#}}"}, ], - 'memory': None, - 'context': { - 'enabled': False - }, - 'vision': { - 'enabled': False - } - } - } + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, + }, + }, ) # construct variable pool - pool = VariablePool(system_variables={ - SystemVariableKey.QUERY: 'what\'s the weather today?', - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: 'abababa', - SystemVariableKey.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) - pool.add(['abc', 'output'], 'sunny') + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather today?", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["abc", "output"], "sunny") - credentials = { - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} - provider_instance = ModelProviderFactory().get_provider_instance('openai') + provider_instance = ModelProviderFactory().get_provider_instance("openai") model_type_instance = provider_instance.get_model_instance(ModelType.LLM) provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( - tenant_id='1', + tenant_id="1", provider=provider_instance.get_provider_schema(), preferred_provider_type=ProviderType.CUSTOM, using_provider_type=ProviderType.CUSTOM, - system_configuration=SystemConfiguration( - enabled=False - ), - custom_configuration=CustomConfiguration( - provider=CustomProviderConfiguration( - credentials=credentials - ) - ), - model_settings=[] + system_configuration=SystemConfiguration(enabled=False), + custom_configuration=CustomConfiguration(provider=CustomProviderConfiguration(credentials=credentials)), + model_settings=[], ), provider_instance=provider_instance, - model_type_instance=model_type_instance + model_type_instance=model_type_instance, ) - model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo') + model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo") model_config = ModelConfigWithCredentialsEntity( - model='gpt-3.5-turbo', - provider='openai', - mode='chat', + model="gpt-3.5-turbo", + provider="openai", + mode="chat", credentials=credentials, parameters={}, - model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'), - provider_model_bundle=provider_model_bundle + model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"), + provider_model_bundle=provider_model_bundle, ) # Mock db.session.close() @@ -118,112 +99,97 @@ def test_execute_llm(setup_openai_mock): result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['text'] is not None - assert result.outputs['usage']['total_tokens'] > 0 + assert result.outputs["text"] is not None + assert result.outputs["usage"]["total_tokens"] > 0 -@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): """ Test execute LLM node with jinja2 """ node = LLMNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'llm', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo', - 'mode': 'chat', - 'completion_params': {} + "id": "llm", + "data": { + "title": "123", + "type": "llm", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "prompt_config": { + "jinja2_variables": [ + {"variable": "sys_query", "value_selector": ["sys", "query"]}, + {"variable": "output", "value_selector": ["abc", "output"]}, + ] }, - 'prompt_config': { - 'jinja2_variables': [{ - 'variable': 'sys_query', - 'value_selector': ['sys', 'query'] - }, { - 'variable': 'output', - 'value_selector': ['abc', 'output'] - }] - }, - 'prompt_template': [ + "prompt_template": [ { - 'role': 'system', - 'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}', - 'jinja2_text': 'you are a helpful assistant.\ntoday\'s weather is {{output}}.', - 'edition_type': 'jinja2' + "role": "system", + "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}", + "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.", + "edition_type": "jinja2", }, { - 'role': 'user', - 'text': '{{#sys.query#}}', - 'jinja2_text': '{{sys_query}}', - 'edition_type': 'basic' - } + "role": "user", + "text": "{{#sys.query#}}", + "jinja2_text": "{{sys_query}}", + "edition_type": "basic", + }, ], - 'memory': None, - 'context': { - 'enabled': False - }, - 'vision': { - 'enabled': False - } - } - } + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, + }, + }, ) # construct variable pool - pool = VariablePool(system_variables={ - SystemVariableKey.QUERY: 'what\'s the weather today?', - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: 'abababa', - SystemVariableKey.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) - pool.add(['abc', 'output'], 'sunny') + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather today?", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["abc", "output"], "sunny") - credentials = { - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} - provider_instance = ModelProviderFactory().get_provider_instance('openai') + provider_instance = ModelProviderFactory().get_provider_instance("openai") model_type_instance = provider_instance.get_model_instance(ModelType.LLM) provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( - tenant_id='1', + tenant_id="1", provider=provider_instance.get_provider_schema(), preferred_provider_type=ProviderType.CUSTOM, using_provider_type=ProviderType.CUSTOM, - system_configuration=SystemConfiguration( - enabled=False - ), - custom_configuration=CustomConfiguration( - provider=CustomProviderConfiguration( - credentials=credentials - ) - ), - model_settings=[] + system_configuration=SystemConfiguration(enabled=False), + custom_configuration=CustomConfiguration(provider=CustomProviderConfiguration(credentials=credentials)), + model_settings=[], ), provider_instance=provider_instance, model_type_instance=model_type_instance, ) - model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo') + model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo") model_config = ModelConfigWithCredentialsEntity( - model='gpt-3.5-turbo', - provider='openai', - mode='chat', + model="gpt-3.5-turbo", + provider="openai", + mode="chat", credentials=credentials, parameters={}, - model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'), - provider_model_bundle=provider_model_bundle + model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"), + provider_model_bundle=provider_model_bundle, ) # Mock db.session.close() @@ -235,5 +201,5 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert 'sunny' in json.dumps(result.process_data) - assert 'what\'s the weather today?' in json.dumps(result.process_data) + assert "sunny" in json.dumps(result.process_data) + assert "what's the weather today?" in json.dumps(result.process_data) diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index e32fa59df3..ca2bae5c53 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -26,29 +26,25 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_moc def get_mocked_fetch_model_config( - provider: str, model: str, mode: str, + provider: str, + model: str, + mode: str, credentials: dict, ): provider_instance = ModelProviderFactory().get_provider_instance(provider) model_type_instance = provider_instance.get_model_instance(ModelType.LLM) provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( - tenant_id='1', + tenant_id="1", provider=provider_instance.get_provider_schema(), preferred_provider_type=ProviderType.CUSTOM, using_provider_type=ProviderType.CUSTOM, - system_configuration=SystemConfiguration( - enabled=False - ), - custom_configuration=CustomConfiguration( - provider=CustomProviderConfiguration( - credentials=credentials - ) - ), - model_settings=[] + system_configuration=SystemConfiguration(enabled=False), + custom_configuration=CustomConfiguration(provider=CustomProviderConfiguration(credentials=credentials)), + model_settings=[], ), provider_instance=provider_instance, - model_type_instance=model_type_instance + model_type_instance=model_type_instance, ) model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model=model) model_config = ModelConfigWithCredentialsEntity( @@ -58,268 +54,268 @@ def get_mocked_fetch_model_config( credentials=credentials, parameters={}, model_schema=model_type_instance.get_model_schema(model), - provider_model_bundle=provider_model_bundle + provider_model_bundle=provider_model_bundle, ) return MagicMock(return_value=(model_instance, model_config)) + def get_mocked_fetch_memory(memory_text: str): class MemoryMock: - def get_history_prompt_text(self, human_prefix: str = "Human", - ai_prefix: str = "Assistant", - max_token_limit: int = 2000, - message_limit: Optional[int] = None): + def get_history_prompt_text( + self, + human_prefix: str = "Human", + ai_prefix: str = "Assistant", + max_token_limit: int = 2000, + message_limit: Optional[int] = None, + ): return memory_text return MagicMock(return_value=MemoryMock()) -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_function_calling_parameter_extractor(setup_openai_mock): """ Test function calling for parameter extractor. """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo', - 'mode': 'chat', - 'completion_params': {} - }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'instruction': '', - 'reasoning_mode': 'function_call', - 'memory': None, - } - } + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "instruction": "", + "reasoning_mode": "function_call", + "memory": None, + }, + }, ) node._fetch_model_config = get_mocked_fetch_model_config( - provider='openai', model='gpt-3.5-turbo', mode='chat', credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + provider="openai", + model="gpt-3.5-turbo", + mode="chat", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, ) db.session.close = MagicMock() # construct variable pool - pool = VariablePool(system_variables={ - SystemVariableKey.QUERY: 'what\'s the weather in SF', - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: 'abababa', - SystemVariableKey.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs.get('location') == 'kawaii' - assert result.outputs.get('__reason') == None + assert result.outputs.get("location") == "kawaii" + assert result.outputs.get("__reason") == None -@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) def test_instructions(setup_openai_mock): """ Test chat parameter extractor. """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo', - 'mode': 'chat', - 'completion_params': {} - }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'reasoning_mode': 'function_call', - 'instruction': '{{#sys.query#}}', - 'memory': None, - } - } + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "function_call", + "instruction": "{{#sys.query#}}", + "memory": None, + }, + }, ) node._fetch_model_config = get_mocked_fetch_model_config( - provider='openai', model='gpt-3.5-turbo', mode='chat', credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + provider="openai", + model="gpt-3.5-turbo", + mode="chat", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, ) db.session.close = MagicMock() # construct variable pool - pool = VariablePool(system_variables={ - SystemVariableKey.QUERY: 'what\'s the weather in SF', - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: 'abababa', - SystemVariableKey.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs.get('location') == 'kawaii' - assert result.outputs.get('__reason') == None + assert result.outputs.get("location") == "kawaii" + assert result.outputs.get("__reason") == None process_data = result.process_data - process_data.get('prompts') + process_data.get("prompts") - for prompt in process_data.get('prompts'): - if prompt.get('role') == 'system': - assert 'what\'s the weather in SF' in prompt.get('text') + for prompt in process_data.get("prompts"): + if prompt.get("role") == "system": + assert "what's the weather in SF" in prompt.get("text") -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_chat_parameter_extractor(setup_anthropic_mock): """ Test chat parameter extractor. """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'anthropic', - 'name': 'claude-2', - 'mode': 'chat', - 'completion_params': {} - }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'reasoning_mode': 'prompt', - 'instruction': '', - 'memory': None, - } - } + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": {"provider": "anthropic", "name": "claude-2", "mode": "chat", "completion_params": {}}, + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "prompt", + "instruction": "", + "memory": None, + }, + }, ) node._fetch_model_config = get_mocked_fetch_model_config( - provider='anthropic', model='claude-2', mode='chat', credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - } + provider="anthropic", + model="claude-2", + mode="chat", + credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}, ) db.session.close = MagicMock() # construct variable pool - pool = VariablePool(system_variables={ - SystemVariableKey.QUERY: 'what\'s the weather in SF', - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: 'abababa', - SystemVariableKey.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs.get('location') == '' - assert result.outputs.get('__reason') == 'Failed to extract result from function call or text response, using empty result.' - prompts = result.process_data.get('prompts') + assert result.outputs.get("location") == "" + assert ( + result.outputs.get("__reason") + == "Failed to extract result from function call or text response, using empty result." + ) + prompts = result.process_data.get("prompts") for prompt in prompts: - if prompt.get('role') == 'user': - if '' in prompt.get('text'): - assert '\n{"type": "object"' in prompt.get('text') + if prompt.get("role") == "user": + if "" in prompt.get("text"): + assert '\n{"type": "object"' in prompt.get("text") -@pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True) + +@pytest.mark.parametrize("setup_openai_mock", [["completion"]], indirect=True) def test_completion_parameter_extractor(setup_openai_mock): """ Test completion parameter extractor. """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo-instruct', - 'mode': 'completion', - 'completion_params': {} + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": { + "provider": "openai", + "name": "gpt-3.5-turbo-instruct", + "mode": "completion", + "completion_params": {}, }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'reasoning_mode': 'prompt', - 'instruction': '{{#sys.query#}}', - 'memory': None, - } - } + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "prompt", + "instruction": "{{#sys.query#}}", + "memory": None, + }, + }, ) node._fetch_model_config = get_mocked_fetch_model_config( - provider='openai', model='gpt-3.5-turbo-instruct', mode='completion', credentials={ - 'openai_api_key': os.environ.get('OPENAI_API_KEY') - } + provider="openai", + model="gpt-3.5-turbo-instruct", + mode="completion", + credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, ) db.session.close = MagicMock() # construct variable pool - pool = VariablePool(system_variables={ - SystemVariableKey.QUERY: 'what\'s the weather in SF', - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: 'abababa', - SystemVariableKey.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs.get('location') == '' - assert result.outputs.get('__reason') == 'Failed to extract result from function call or text response, using empty result.' - assert len(result.process_data.get('prompts')) == 1 - assert 'SF' in result.process_data.get('prompts')[0].get('text') + assert result.outputs.get("location") == "" + assert ( + result.outputs.get("__reason") + == "Failed to extract result from function call or text response, using empty result." + ) + assert len(result.process_data.get("prompts")) == 1 + assert "SF" in result.process_data.get("prompts")[0].get("text") + def test_extract_json_response(): """ @@ -327,35 +323,30 @@ def test_extract_json_response(): """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'openai', - 'name': 'gpt-3.5-turbo-instruct', - 'mode': 'completion', - 'completion_params': {} + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": { + "provider": "openai", + "name": "gpt-3.5-turbo-instruct", + "mode": "completion", + "completion_params": {}, }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'reasoning_mode': 'prompt', - 'instruction': '{{#sys.query#}}', - 'memory': None, - } - } + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "prompt", + "instruction": "{{#sys.query#}}", + "memory": None, + }, + }, ) result = node._extract_complete_json_response(""" @@ -366,83 +357,77 @@ def test_extract_json_response(): hello world. """) - assert result['location'] == 'kawaii' + assert result["location"] == "kawaii" -@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) + +@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True) def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): """ Test chat parameter extractor with memory. """ node = ParameterExtractorNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': 'llm', - 'data': { - 'title': '123', - 'type': 'parameter-extractor', - 'model': { - 'provider': 'anthropic', - 'name': 'claude-2', - 'mode': 'chat', - 'completion_params': {} - }, - 'query': ['sys', 'query'], - 'parameters': [{ - 'name': 'location', - 'type': 'string', - 'description': 'location', - 'required': True - }], - 'reasoning_mode': 'prompt', - 'instruction': '', - 'memory': { - 'window': { - 'enabled': True, - 'size': 50 - } - }, - } - } + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": {"provider": "anthropic", "name": "claude-2", "mode": "chat", "completion_params": {}}, + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "prompt", + "instruction": "", + "memory": {"window": {"enabled": True, "size": 50}}, + }, + }, ) node._fetch_model_config = get_mocked_fetch_model_config( - provider='anthropic', model='claude-2', mode='chat', credentials={ - 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') - } + provider="anthropic", + model="claude-2", + mode="chat", + credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}, ) - node._fetch_memory = get_mocked_fetch_memory('customized memory') + node._fetch_memory = get_mocked_fetch_memory("customized memory") db.session.close = MagicMock() # construct variable pool - pool = VariablePool(system_variables={ - SystemVariableKey.QUERY: 'what\'s the weather in SF', - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: 'abababa', - SystemVariableKey.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "what's the weather in SF", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={}, + environment_variables=[], + ) result = node.run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs.get('location') == '' - assert result.outputs.get('__reason') == 'Failed to extract result from function call or text response, using empty result.' - prompts = result.process_data.get('prompts') + assert result.outputs.get("location") == "" + assert ( + result.outputs.get("__reason") + == "Failed to extract result from function call or text response, using empty result." + ) + prompts = result.process_data.get("prompts") latest_role = None for prompt in prompts: - if prompt.get('role') == 'user': - if '' in prompt.get('text'): - assert '\n{"type": "object"' in prompt.get('text') - elif prompt.get('role') == 'system': - assert 'customized memory' in prompt.get('text') + if prompt.get("role") == "user": + if "" in prompt.get("text"): + assert '\n{"type": "object"' in prompt.get("text") + elif prompt.get("role") == "system": + assert "customized memory" in prompt.get("text") if latest_role is not None: - assert latest_role != prompt.get('role') + assert latest_role != prompt.get("role") - if prompt.get('role') in ['user', 'assistant']: - latest_role = prompt.get('role') + if prompt.get("role") in ["user", "assistant"]: + latest_role = prompt.get("role") diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 781dfbc50f..617b6370c9 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -8,42 +8,39 @@ from models.workflow import WorkflowNodeExecutionStatus from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock -@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) +@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) def test_execute_code(setup_code_executor_mock): - code = '''{{args2}}''' + code = """{{args2}}""" node = TemplateTransformNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.END_USER, config={ - 'id': '1', - 'data': { - 'title': '123', - 'variables': [ + "id": "1", + "data": { + "title": "123", + "variables": [ { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], + "variable": "args1", + "value_selector": ["1", "123", "args1"], }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } + {"variable": "args2", "value_selector": ["1", "123", "args2"]}, ], - 'template': code, - } - } + "template": code, + }, + }, ) # construct variable pool pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(['1', '123', 'args1'], 1) - pool.add(['1', '123', 'args2'], 3) - + pool.add(["1", "123", "args1"], 1) + pool.add(["1", "123", "args2"], 3) + # execute node result = node.run(pool) - + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['output'] == '3' + assert result.outputs["output"] == "3" diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 01d62280e8..29c1efa8e7 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -7,78 +7,79 @@ from models.workflow import WorkflowNodeExecutionStatus def test_tool_variable_invoke(): pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(['1', '123', 'args1'], '1+1') + pool.add(["1", "123", "args1"], "1+1") node = ToolNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': '1', - 'data': { - 'title': 'a', - 'desc': 'a', - 'provider_id': 'maths', - 'provider_type': 'builtin', - 'provider_name': 'maths', - 'tool_name': 'eval_expression', - 'tool_label': 'eval_expression', - 'tool_configurations': {}, - 'tool_parameters': { - 'expression': { - 'type': 'variable', - 'value': ['1', '123', 'args1'], + "id": "1", + "data": { + "title": "a", + "desc": "a", + "provider_id": "maths", + "provider_type": "builtin", + "provider_name": "maths", + "tool_name": "eval_expression", + "tool_label": "eval_expression", + "tool_configurations": {}, + "tool_parameters": { + "expression": { + "type": "variable", + "value": ["1", "123", "args1"], } - } - } - } + }, + }, + }, ) # execute node result = node.run(pool) - + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert '2' in result.outputs['text'] - assert result.outputs['files'] == [] + assert "2" in result.outputs["text"] + assert result.outputs["files"] == [] + def test_tool_mixed_invoke(): pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) - pool.add(['1', 'args1'], '1+1') + pool.add(["1", "args1"], "1+1") node = ToolNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", invoke_from=InvokeFrom.WEB_APP, user_from=UserFrom.ACCOUNT, config={ - 'id': '1', - 'data': { - 'title': 'a', - 'desc': 'a', - 'provider_id': 'maths', - 'provider_type': 'builtin', - 'provider_name': 'maths', - 'tool_name': 'eval_expression', - 'tool_label': 'eval_expression', - 'tool_configurations': {}, - 'tool_parameters': { - 'expression': { - 'type': 'mixed', - 'value': '{{#1.args1#}}', + "id": "1", + "data": { + "title": "a", + "desc": "a", + "provider_id": "maths", + "provider_type": "builtin", + "provider_name": "maths", + "tool_name": "eval_expression", + "tool_label": "eval_expression", + "tool_configurations": {}, + "tool_parameters": { + "expression": { + "type": "mixed", + "value": "{{#1.args1#}}", } - } - } - } + }, + }, + }, ) # execute node result = node.run(pool) - + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert '2' in result.outputs['text'] - assert result.outputs['files'] == [] \ No newline at end of file + assert "2" in result.outputs["text"] + assert result.outputs["files"] == [] diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index 949a5a1769..39f313b513 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -6,18 +6,21 @@ from flask import Flask from configs.app_config import DifyConfig -EXAMPLE_ENV_FILENAME = '.env' +EXAMPLE_ENV_FILENAME = ".env" @pytest.fixture def example_env_file(tmp_path, monkeypatch) -> str: monkeypatch.chdir(tmp_path) file_path = tmp_path.joinpath(EXAMPLE_ENV_FILENAME) - file_path.write_text(dedent( - """ + file_path.write_text( + dedent( + """ CONSOLE_API_URL=https://example.com CONSOLE_WEB_URL=https://example.com - """)) + """ + ) + ) return str(file_path) @@ -29,7 +32,7 @@ def test_dify_config_undefined_entry(example_env_file): # entries not defined in app settings with pytest.raises(TypeError): # TypeError: 'AppSettings' object is not subscriptable - assert config['LOG_LEVEL'] == 'INFO' + assert config["LOG_LEVEL"] == "INFO" def test_dify_config(example_env_file): @@ -37,10 +40,10 @@ def test_dify_config(example_env_file): config = DifyConfig(_env_file=example_env_file) # constant values - assert config.COMMIT_SHA == '' + assert config.COMMIT_SHA == "" # default values - assert config.EDITION == 'SELF_HOSTED' + assert config.EDITION == "SELF_HOSTED" assert config.API_COMPRESSION_ENABLED is False assert config.SENTRY_TRACES_SAMPLE_RATE == 1.0 @@ -48,36 +51,36 @@ def test_dify_config(example_env_file): # NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected. # This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`. def test_flask_configs(example_env_file): - flask_app = Flask('app') + flask_app = Flask("app") # clear system environment variables os.environ.clear() flask_app.config.from_mapping(DifyConfig(_env_file=example_env_file).model_dump()) # pyright: ignore config = flask_app.config # configs read from pydantic-settings - assert config['LOG_LEVEL'] == 'INFO' - assert config['COMMIT_SHA'] == '' - assert config['EDITION'] == 'SELF_HOSTED' - assert config['API_COMPRESSION_ENABLED'] is False - assert config['SENTRY_TRACES_SAMPLE_RATE'] == 1.0 - assert config['TESTING'] == False + assert config["LOG_LEVEL"] == "INFO" + assert config["COMMIT_SHA"] == "" + assert config["EDITION"] == "SELF_HOSTED" + assert config["API_COMPRESSION_ENABLED"] is False + assert config["SENTRY_TRACES_SAMPLE_RATE"] == 1.0 + assert config["TESTING"] == False # value from env file - assert config['CONSOLE_API_URL'] == 'https://example.com' + assert config["CONSOLE_API_URL"] == "https://example.com" # fallback to alias choices value as CONSOLE_API_URL - assert config['FILES_URL'] == 'https://example.com' + assert config["FILES_URL"] == "https://example.com" - assert config['SQLALCHEMY_DATABASE_URI'] == 'postgresql://postgres:@localhost:5432/dify' - assert config['SQLALCHEMY_ENGINE_OPTIONS'] == { - 'connect_args': { - 'options': '-c timezone=UTC', + assert config["SQLALCHEMY_DATABASE_URI"] == "postgresql://postgres:@localhost:5432/dify" + assert config["SQLALCHEMY_ENGINE_OPTIONS"] == { + "connect_args": { + "options": "-c timezone=UTC", }, - 'max_overflow': 10, - 'pool_pre_ping': False, - 'pool_recycle': 3600, - 'pool_size': 30, + "max_overflow": 10, + "pool_pre_ping": False, + "pool_recycle": 3600, + "pool_size": 30, } - assert config['CONSOLE_WEB_URL']=='https://example.com' - assert config['CONSOLE_CORS_ALLOW_ORIGINS']==['https://example.com'] - assert config['WEB_API_CORS_ALLOW_ORIGINS'] == ['*'] + assert config["CONSOLE_WEB_URL"] == "https://example.com" + assert config["CONSOLE_CORS_ALLOW_ORIGINS"] == ["https://example.com"] + assert config["WEB_API_CORS_ALLOW_ORIGINS"] == ["*"] diff --git a/api/tests/unit_tests/core/app/segments/test_factory.py b/api/tests/unit_tests/core/app/segments/test_factory.py index afd0fa50b5..0824c8e9e9 100644 --- a/api/tests/unit_tests/core/app/segments/test_factory.py +++ b/api/tests/unit_tests/core/app/segments/test_factory.py @@ -17,31 +17,31 @@ from core.app.segments.exc import VariableError def test_string_variable(): - test_data = {'value_type': 'string', 'name': 'test_text', 'value': 'Hello, World!'} + test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"} result = factory.build_variable_from_mapping(test_data) assert isinstance(result, StringVariable) def test_integer_variable(): - test_data = {'value_type': 'number', 'name': 'test_int', 'value': 42} + test_data = {"value_type": "number", "name": "test_int", "value": 42} result = factory.build_variable_from_mapping(test_data) assert isinstance(result, IntegerVariable) def test_float_variable(): - test_data = {'value_type': 'number', 'name': 'test_float', 'value': 3.14} + test_data = {"value_type": "number", "name": "test_float", "value": 3.14} result = factory.build_variable_from_mapping(test_data) assert isinstance(result, FloatVariable) def test_secret_variable(): - test_data = {'value_type': 'secret', 'name': 'test_secret', 'value': 'secret_value'} + test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"} result = factory.build_variable_from_mapping(test_data) assert isinstance(result, SecretVariable) def test_invalid_value_type(): - test_data = {'value_type': 'unknown', 'name': 'test_invalid', 'value': 'value'} + test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"} with pytest.raises(VariableError): factory.build_variable_from_mapping(test_data) @@ -49,51 +49,51 @@ def test_invalid_value_type(): def test_build_a_blank_string(): result = factory.build_variable_from_mapping( { - 'value_type': 'string', - 'name': 'blank', - 'value': '', + "value_type": "string", + "name": "blank", + "value": "", } ) assert isinstance(result, StringVariable) - assert result.value == '' + assert result.value == "" def test_build_a_object_variable_with_none_value(): var = factory.build_segment( { - 'key1': None, + "key1": None, } ) assert isinstance(var, ObjectSegment) - assert var.value['key1'] is None + assert var.value["key1"] is None def test_object_variable(): mapping = { - 'id': str(uuid4()), - 'value_type': 'object', - 'name': 'test_object', - 'description': 'Description of the variable.', - 'value': { - 'key1': 'text', - 'key2': 2, + "id": str(uuid4()), + "value_type": "object", + "name": "test_object", + "description": "Description of the variable.", + "value": { + "key1": "text", + "key2": 2, }, } variable = factory.build_variable_from_mapping(mapping) assert isinstance(variable, ObjectSegment) - assert isinstance(variable.value['key1'], str) - assert isinstance(variable.value['key2'], int) + assert isinstance(variable.value["key1"], str) + assert isinstance(variable.value["key2"], int) def test_array_string_variable(): mapping = { - 'id': str(uuid4()), - 'value_type': 'array[string]', - 'name': 'test_array', - 'description': 'Description of the variable.', - 'value': [ - 'text', - 'text', + "id": str(uuid4()), + "value_type": "array[string]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ + "text", + "text", ], } variable = factory.build_variable_from_mapping(mapping) @@ -104,11 +104,11 @@ def test_array_string_variable(): def test_array_number_variable(): mapping = { - 'id': str(uuid4()), - 'value_type': 'array[number]', - 'name': 'test_array', - 'description': 'Description of the variable.', - 'value': [ + "id": str(uuid4()), + "value_type": "array[number]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ 1, 2.0, ], @@ -121,18 +121,18 @@ def test_array_number_variable(): def test_array_object_variable(): mapping = { - 'id': str(uuid4()), - 'value_type': 'array[object]', - 'name': 'test_array', - 'description': 'Description of the variable.', - 'value': [ + "id": str(uuid4()), + "value_type": "array[object]", + "name": "test_array", + "description": "Description of the variable.", + "value": [ { - 'key1': 'text', - 'key2': 1, + "key1": "text", + "key2": 1, }, { - 'key1': 'text', - 'key2': 1, + "key1": "text", + "key2": 1, }, ], } @@ -140,19 +140,19 @@ def test_array_object_variable(): assert isinstance(variable, ArrayObjectVariable) assert isinstance(variable.value[0], dict) assert isinstance(variable.value[1], dict) - assert isinstance(variable.value[0]['key1'], str) - assert isinstance(variable.value[0]['key2'], int) - assert isinstance(variable.value[1]['key1'], str) - assert isinstance(variable.value[1]['key2'], int) + assert isinstance(variable.value[0]["key1"], str) + assert isinstance(variable.value[0]["key2"], int) + assert isinstance(variable.value[1]["key1"], str) + assert isinstance(variable.value[1]["key2"], int) def test_variable_cannot_large_than_5_kb(): with pytest.raises(VariableError): factory.build_variable_from_mapping( { - 'id': str(uuid4()), - 'value_type': 'string', - 'name': 'test_text', - 'value': 'a' * 1024 * 6, + "id": str(uuid4()), + "value_type": "string", + "name": "test_text", + "value": "a" * 1024 * 6, } ) diff --git a/api/tests/unit_tests/core/app/segments/test_segment.py b/api/tests/unit_tests/core/app/segments/test_segment.py index 50d991316d..7cc339d212 100644 --- a/api/tests/unit_tests/core/app/segments/test_segment.py +++ b/api/tests/unit_tests/core/app/segments/test_segment.py @@ -7,20 +7,20 @@ from core.workflow.enums import SystemVariableKey def test_segment_group_to_text(): variable_pool = VariablePool( system_variables={ - SystemVariableKey('user_id'): 'fake-user-id', + SystemVariableKey("user_id"): "fake-user-id", }, user_inputs={}, environment_variables=[ - SecretVariable(name='secret_key', value='fake-secret-key'), + SecretVariable(name="secret_key", value="fake-secret-key"), ], ) - variable_pool.add(('node_id', 'custom_query'), 'fake-user-query') + variable_pool.add(("node_id", "custom_query"), "fake-user-query") template = ( - 'Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}.' + "Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}." ) segments_group = parser.convert_template(template=template, variable_pool=variable_pool) - assert segments_group.text == 'Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key.' + assert segments_group.text == "Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key." assert ( segments_group.log == f"Hello, fake-user-id! Your query is fake-user-query. And your key is {encrypter.obfuscated_token('fake-secret-key')}." @@ -33,22 +33,22 @@ def test_convert_constant_to_segment_group(): user_inputs={}, environment_variables=[], ) - template = 'Hello, world!' + template = "Hello, world!" segments_group = parser.convert_template(template=template, variable_pool=variable_pool) - assert segments_group.text == 'Hello, world!' - assert segments_group.log == 'Hello, world!' + assert segments_group.text == "Hello, world!" + assert segments_group.log == "Hello, world!" def test_convert_variable_to_segment_group(): variable_pool = VariablePool( system_variables={ - SystemVariableKey('user_id'): 'fake-user-id', + SystemVariableKey("user_id"): "fake-user-id", }, user_inputs={}, environment_variables=[], ) - template = '{{#sys.user_id#}}' + template = "{{#sys.user_id#}}" segments_group = parser.convert_template(template=template, variable_pool=variable_pool) - assert segments_group.text == 'fake-user-id' - assert segments_group.log == 'fake-user-id' - assert segments_group.value == [StringSegment(value='fake-user-id')] + assert segments_group.text == "fake-user-id" + assert segments_group.log == "fake-user-id" + assert segments_group.value == [StringSegment(value="fake-user-id")] diff --git a/api/tests/unit_tests/core/app/segments/test_variables.py b/api/tests/unit_tests/core/app/segments/test_variables.py index 1f45c15f87..b3f0ae626c 100644 --- a/api/tests/unit_tests/core/app/segments/test_variables.py +++ b/api/tests/unit_tests/core/app/segments/test_variables.py @@ -13,60 +13,60 @@ from core.app.segments import ( def test_frozen_variables(): - var = StringVariable(name='text', value='text') + var = StringVariable(name="text", value="text") with pytest.raises(ValidationError): - var.value = 'new value' + var.value = "new value" - int_var = IntegerVariable(name='integer', value=42) + int_var = IntegerVariable(name="integer", value=42) with pytest.raises(ValidationError): int_var.value = 100 - float_var = FloatVariable(name='float', value=3.14) + float_var = FloatVariable(name="float", value=3.14) with pytest.raises(ValidationError): float_var.value = 2.718 - secret_var = SecretVariable(name='secret', value='secret_value') + secret_var = SecretVariable(name="secret", value="secret_value") with pytest.raises(ValidationError): - secret_var.value = 'new_secret_value' + secret_var.value = "new_secret_value" def test_variable_value_type_immutable(): with pytest.raises(ValidationError): - StringVariable(value_type=SegmentType.ARRAY_ANY, name='text', value='text') + StringVariable(value_type=SegmentType.ARRAY_ANY, name="text", value="text") with pytest.raises(ValidationError): - StringVariable.model_validate({'value_type': 'not text', 'name': 'text', 'value': 'text'}) + StringVariable.model_validate({"value_type": "not text", "name": "text", "value": "text"}) - var = IntegerVariable(name='integer', value=42) + var = IntegerVariable(name="integer", value=42) with pytest.raises(ValidationError): IntegerVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value) - var = FloatVariable(name='float', value=3.14) + var = FloatVariable(name="float", value=3.14) with pytest.raises(ValidationError): FloatVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value) - var = SecretVariable(name='secret', value='secret_value') + var = SecretVariable(name="secret", value="secret_value") with pytest.raises(ValidationError): SecretVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value) def test_object_variable_to_object(): var = ObjectVariable( - name='object', + name="object", value={ - 'key1': { - 'key2': 'value2', + "key1": { + "key2": "value2", }, - 'key2': ['value5_1', 42, {}], + "key2": ["value5_1", 42, {}], }, ) assert var.to_object() == { - 'key1': { - 'key2': 'value2', + "key1": { + "key2": "value2", }, - 'key2': [ - 'value5_1', + "key2": [ + "value5_1", 42, {}, ], @@ -74,11 +74,11 @@ def test_object_variable_to_object(): def test_variable_to_object(): - var = StringVariable(name='text', value='text') - assert var.to_object() == 'text' - var = IntegerVariable(name='integer', value=42) + var = StringVariable(name="text", value="text") + assert var.to_object() == "text" + var = IntegerVariable(name="integer", value=42) assert var.to_object() == 42 - var = FloatVariable(name='float', value=3.14) + var = FloatVariable(name="float", value=3.14) assert var.to_object() == 3.14 - var = SecretVariable(name='secret', value='secret_value') - assert var.to_object() == 'secret_value' + var = SecretVariable(name="secret", value="secret_value") + assert var.to_object() == "secret_value" diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py index d917bb1003..7a0bc70c63 100644 --- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -4,17 +4,17 @@ from unittest.mock import MagicMock, patch from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request -@patch('httpx.request') +@patch("httpx.request") def test_successful_request(mock_request): mock_response = MagicMock() mock_response.status_code = 200 mock_request.return_value = mock_response - response = make_request('GET', 'http://example.com') + response = make_request("GET", "http://example.com") assert response.status_code == 200 -@patch('httpx.request') +@patch("httpx.request") def test_retry_exceed_max_retries(mock_request): mock_response = MagicMock() mock_response.status_code = 500 @@ -23,13 +23,13 @@ def test_retry_exceed_max_retries(mock_request): mock_request.side_effect = side_effects try: - make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES - 1) + make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES - 1) raise AssertionError("Expected Exception not raised") except Exception as e: assert str(e) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com" -@patch('httpx.request') +@patch("httpx.request") def test_retry_logic_success(mock_request): side_effects = [] @@ -45,8 +45,8 @@ def test_retry_logic_success(mock_request): mock_request.side_effect = side_effects - response = make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES) + response = make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES) assert response.status_code == 200 assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1 - assert mock_request.call_args_list[0][1].get('method') == 'GET' + assert mock_request.call_args_list[0][1].get("method") == "GET" diff --git a/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py b/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py index 68334fde82..5b159b49b6 100644 --- a/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py +++ b/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py @@ -21,18 +21,18 @@ def test_max_chunks(): def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding: return _MockTextEmbedding() - model = 'embedding-v1' + model = "embedding-v1" credentials = { - 'api_key': 'xxxx', - 'secret_key': 'yyyy', + "api_key": "xxxx", + "secret_key": "yyyy", } embedding_model = WenxinTextEmbeddingModel() context_size = embedding_model._get_context_size(model, credentials) max_chunks = embedding_model._get_max_chunks(model, credentials) embedding_model._create_text_embedding = _create_text_embedding - texts = ['0123456789' for i in range(0, max_chunks * 2)] - result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test') + texts = ["0123456789" for i in range(0, max_chunks * 2)] + result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, "test") assert len(result.embeddings) == max_chunks * 2 @@ -41,16 +41,16 @@ def test_context_size(): return GPT2Tokenizer.get_num_tokens(text) def mock_text(token_size: int) -> str: - _text = "".join(['0' for i in range(token_size)]) + _text = "".join(["0" for i in range(token_size)]) num_tokens = get_num_tokens_by_gpt2(_text) ratio = int(np.floor(len(_text) / num_tokens)) m_text = "".join([_text for i in range(ratio)]) return m_text - model = 'embedding-v1' + model = "embedding-v1" credentials = { - 'api_key': 'xxxx', - 'secret_key': 'yyyy', + "api_key": "xxxx", + "secret_key": "yyyy", } embedding_model = WenxinTextEmbeddingModel() context_size = embedding_model._get_context_size(model, credentials) @@ -71,5 +71,5 @@ def test_context_size(): assert get_num_tokens_by_gpt2(text) == context_size * 2 texts = [text] - result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test') + result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, "test") assert result.usage.tokens == context_size diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index d24cd4aae9..24bbde6d4e 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -14,39 +14,24 @@ from models.model import Conversation def test__get_completion_model_prompt_messages(): model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' - model_config_mock.model = 'gpt-3.5-turbo-instruct' + model_config_mock.provider = "openai" + model_config_mock.model = "gpt-3.5-turbo-instruct" prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}." - prompt_template_config = CompletionModelPromptTemplate( - text=prompt_template - ) + prompt_template_config = CompletionModelPromptTemplate(text=prompt_template) memory_config = MemoryConfig( - role_prefix=MemoryConfig.RolePrefix( - user="Human", - assistant="Assistant" - ), - window=MemoryConfig.WindowConfig( - enabled=False - ) + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=False), ) - inputs = { - "name": "John" - } + inputs = {"name": "John"} files = [] context = "I am superman." - memory = TokenBufferMemory( - conversation=Conversation(), - model_instance=model_config_mock - ) + memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock) - history_prompt_messages = [ - UserPromptMessage(content="Hi"), - AssistantPromptMessage(content="Hello") - ] + history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")] memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages) prompt_transform = AdvancedPromptTransform() @@ -59,16 +44,22 @@ def test__get_completion_model_prompt_messages(): context=context, memory_config=memory_config, memory=memory, - model_config=model_config_mock + model_config=model_config_mock, ) assert len(prompt_messages) == 1 - assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format({ - "#context#": context, - "#histories#": "\n".join([f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: " - f"{prompt.content}" for prompt in history_prompt_messages]), - **inputs, - }) + assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format( + { + "#context#": context, + "#histories#": "\n".join( + [ + f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: " f"{prompt.content}" + for prompt in history_prompt_messages + ] + ), + **inputs, + } + ) def test__get_chat_model_prompt_messages(get_chat_model_args): @@ -77,15 +68,9 @@ def test__get_chat_model_prompt_messages(get_chat_model_args): files = [] query = "Hi2." - memory = TokenBufferMemory( - conversation=Conversation(), - model_instance=model_config_mock - ) + memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock) - history_prompt_messages = [ - UserPromptMessage(content="Hi1."), - AssistantPromptMessage(content="Hello1!") - ] + history_prompt_messages = [UserPromptMessage(content="Hi1."), AssistantPromptMessage(content="Hello1!")] memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages) prompt_transform = AdvancedPromptTransform() @@ -98,14 +83,14 @@ def test__get_chat_model_prompt_messages(get_chat_model_args): context=context, memory_config=memory_config, memory=memory, - model_config=model_config_mock + model_config=model_config_mock, ) assert len(prompt_messages) == 6 assert prompt_messages[0].role == PromptMessageRole.SYSTEM - assert prompt_messages[0].content == PromptTemplateParser( - template=messages[0].text - ).format({**inputs, "#context#": context}) + assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format( + {**inputs, "#context#": context} + ) assert prompt_messages[5].content == query @@ -124,14 +109,14 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args): context=context, memory_config=None, memory=None, - model_config=model_config_mock + model_config=model_config_mock, ) assert len(prompt_messages) == 3 assert prompt_messages[0].role == PromptMessageRole.SYSTEM - assert prompt_messages[0].content == PromptTemplateParser( - template=messages[0].text - ).format({**inputs, "#context#": context}) + assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format( + {**inputs, "#context#": context} + ) def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args): @@ -148,7 +133,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg image_config={ "detail": "high", } - ) + ), ) ] @@ -162,14 +147,14 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg context=context, memory_config=None, memory=None, - model_config=model_config_mock + model_config=model_config_mock, ) assert len(prompt_messages) == 4 assert prompt_messages[0].role == PromptMessageRole.SYSTEM - assert prompt_messages[0].content == PromptTemplateParser( - template=messages[0].text - ).format({**inputs, "#context#": context}) + assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format( + {**inputs, "#context#": context} + ) assert isinstance(prompt_messages[3].content, list) assert len(prompt_messages[3].content) == 2 assert prompt_messages[3].content[1].data == files[0].url @@ -178,33 +163,20 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg @pytest.fixture def get_chat_model_args(): model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' - model_config_mock.model = 'gpt-4' + model_config_mock.provider = "openai" + model_config_mock.model = "gpt-4" - memory_config = MemoryConfig( - window=MemoryConfig.WindowConfig( - enabled=False - ) - ) + memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) prompt_messages = [ ChatModelMessage( - text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", - role=PromptMessageRole.SYSTEM + text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", role=PromptMessageRole.SYSTEM ), - ChatModelMessage( - text="Hi.", - role=PromptMessageRole.USER - ), - ChatModelMessage( - text="Hello!", - role=PromptMessageRole.ASSISTANT - ) + ChatModelMessage(text="Hi.", role=PromptMessageRole.USER), + ChatModelMessage(text="Hello!", role=PromptMessageRole.ASSISTANT), ] - inputs = { - "name": "John" - } + inputs = {"name": "John"} context = "I am superman." diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py index 9de268d762..0fd176e65d 100644 --- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py @@ -18,27 +18,28 @@ from models.model import Conversation def test_get_prompt(): prompt_messages = [ - SystemPromptMessage(content='System Template'), - UserPromptMessage(content='User Query'), + SystemPromptMessage(content="System Template"), + UserPromptMessage(content="User Query"), ] history_messages = [ - SystemPromptMessage(content='System Prompt 1'), - UserPromptMessage(content='User Prompt 1'), - AssistantPromptMessage(content='Assistant Thought 1'), - ToolPromptMessage(content='Tool 1-1', name='Tool 1-1', tool_call_id='1'), - ToolPromptMessage(content='Tool 1-2', name='Tool 1-2', tool_call_id='2'), - SystemPromptMessage(content='System Prompt 2'), - UserPromptMessage(content='User Prompt 2'), - AssistantPromptMessage(content='Assistant Thought 2'), - ToolPromptMessage(content='Tool 2-1', name='Tool 2-1', tool_call_id='3'), - ToolPromptMessage(content='Tool 2-2', name='Tool 2-2', tool_call_id='4'), - UserPromptMessage(content='User Prompt 3'), - AssistantPromptMessage(content='Assistant Thought 3'), + SystemPromptMessage(content="System Prompt 1"), + UserPromptMessage(content="User Prompt 1"), + AssistantPromptMessage(content="Assistant Thought 1"), + ToolPromptMessage(content="Tool 1-1", name="Tool 1-1", tool_call_id="1"), + ToolPromptMessage(content="Tool 1-2", name="Tool 1-2", tool_call_id="2"), + SystemPromptMessage(content="System Prompt 2"), + UserPromptMessage(content="User Prompt 2"), + AssistantPromptMessage(content="Assistant Thought 2"), + ToolPromptMessage(content="Tool 2-1", name="Tool 2-1", tool_call_id="3"), + ToolPromptMessage(content="Tool 2-2", name="Tool 2-2", tool_call_id="4"), + UserPromptMessage(content="User Prompt 3"), + AssistantPromptMessage(content="Assistant Thought 3"), ] # use message number instead of token for testing def side_effect_get_num_tokens(*args): return len(args[2]) + large_language_model_mock = MagicMock(spec=LargeLanguageModel) large_language_model_mock.get_num_tokens = MagicMock(side_effect=side_effect_get_num_tokens) @@ -46,20 +47,17 @@ def test_get_prompt(): provider_model_bundle_mock.model_type_instance = large_language_model_mock model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity) - model_config_mock.model = 'openai' + model_config_mock.model = "openai" model_config_mock.credentials = {} model_config_mock.provider_model_bundle = provider_model_bundle_mock - memory = TokenBufferMemory( - conversation=Conversation(), - model_instance=model_config_mock - ) + memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock) transform = AgentHistoryPromptTransform( model_config=model_config_mock, prompt_messages=prompt_messages, history_messages=history_messages, - memory=memory + memory=memory, ) max_token_limit = 5 diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index 2bcc6f4292..89c14463bb 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -12,19 +12,15 @@ from core.prompt.prompt_transform import PromptTransform def test__calculate_rest_token(): model_schema_mock = MagicMock(spec=AIModelEntity) parameter_rule_mock = MagicMock(spec=ParameterRule) - parameter_rule_mock.name = 'max_tokens' - model_schema_mock.parameter_rules = [ - parameter_rule_mock - ] - model_schema_mock.model_properties = { - ModelPropertyKey.CONTEXT_SIZE: 62 - } + parameter_rule_mock.name = "max_tokens" + model_schema_mock.parameter_rules = [parameter_rule_mock] + model_schema_mock.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 62} large_language_model_mock = MagicMock(spec=LargeLanguageModel) large_language_model_mock.get_num_tokens.return_value = 6 provider_mock = MagicMock(spec=ProviderEntity) - provider_mock.provider = 'openai' + provider_mock.provider = "openai" provider_configuration_mock = MagicMock(spec=ProviderConfiguration) provider_configuration_mock.provider = provider_mock @@ -35,11 +31,9 @@ def test__calculate_rest_token(): provider_model_bundle_mock.configuration = provider_configuration_mock model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.model = 'gpt-4' + model_config_mock.model = "gpt-4" model_config_mock.credentials = {} - model_config_mock.parameters = { - 'max_tokens': 50 - } + model_config_mock.parameters = {"max_tokens": 50} model_config_mock.model_schema = model_schema_mock model_config_mock.provider_model_bundle = provider_model_bundle_mock @@ -49,8 +43,10 @@ def test__calculate_rest_token(): rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock) # Validate based on the mock configuration and expected logic - expected_rest_tokens = (model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE] - - model_config_mock.parameters['max_tokens'] - - large_language_model_mock.get_num_tokens.return_value) + expected_rest_tokens = ( + model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE] + - model_config_mock.parameters["max_tokens"] + - large_language_model_mock.get_num_tokens.return_value + ) assert rest_tokens == expected_rest_tokens assert rest_tokens == 6 diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index 6d6363610b..c32fc2bc34 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -19,12 +19,15 @@ def test_get_common_chat_app_prompt_template_with_pcqm(): query_in_prompt=True, with_memory_prompt=True, ) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] - + pre_prompt + '\n' - + prompt_rules['histories_prompt'] - + prompt_rules['query_prompt']) - assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#'] + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + + pre_prompt + + "\n" + + prompt_rules["histories_prompt"] + + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#histories#", "#query#"] def test_get_baichuan_chat_app_prompt_template_with_pcqm(): @@ -39,12 +42,15 @@ def test_get_baichuan_chat_app_prompt_template_with_pcqm(): query_in_prompt=True, with_memory_prompt=True, ) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] - + pre_prompt + '\n' - + prompt_rules['histories_prompt'] - + prompt_rules['query_prompt']) - assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#'] + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + + pre_prompt + + "\n" + + prompt_rules["histories_prompt"] + + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#histories#", "#query#"] def test_get_common_completion_app_prompt_template_with_pcq(): @@ -59,11 +65,11 @@ def test_get_common_completion_app_prompt_template_with_pcq(): query_in_prompt=True, with_memory_prompt=False, ) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] - + pre_prompt + '\n' - + prompt_rules['query_prompt']) - assert prompt_template['special_variable_keys'] == ['#context#', '#query#'] + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#query#"] def test_get_baichuan_completion_app_prompt_template_with_pcq(): @@ -78,12 +84,12 @@ def test_get_baichuan_completion_app_prompt_template_with_pcq(): query_in_prompt=True, with_memory_prompt=False, ) - print(prompt_template['prompt_template'].template) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] - + pre_prompt + '\n' - + prompt_rules['query_prompt']) - assert prompt_template['special_variable_keys'] == ['#context#', '#query#'] + print(prompt_template["prompt_template"].template) + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#query#"] def test_get_common_chat_app_prompt_template_with_q(): @@ -98,9 +104,9 @@ def test_get_common_chat_app_prompt_template_with_q(): query_in_prompt=True, with_memory_prompt=False, ) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == prompt_rules['query_prompt'] - assert prompt_template['special_variable_keys'] == ['#query#'] + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == prompt_rules["query_prompt"] + assert prompt_template["special_variable_keys"] == ["#query#"] def test_get_common_chat_app_prompt_template_with_cq(): @@ -115,10 +121,11 @@ def test_get_common_chat_app_prompt_template_with_cq(): query_in_prompt=True, with_memory_prompt=False, ) - prompt_rules = prompt_template['prompt_rules'] - assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] - + prompt_rules['query_prompt']) - assert prompt_template['special_variable_keys'] == ['#context#', '#query#'] + prompt_rules = prompt_template["prompt_rules"] + assert prompt_template["prompt_template"].template == ( + prompt_rules["context_prompt"] + prompt_rules["query_prompt"] + ) + assert prompt_template["special_variable_keys"] == ["#context#", "#query#"] def test_get_common_chat_app_prompt_template_with_p(): @@ -133,30 +140,25 @@ def test_get_common_chat_app_prompt_template_with_p(): query_in_prompt=False, with_memory_prompt=False, ) - assert prompt_template['prompt_template'].template == pre_prompt + '\n' - assert prompt_template['custom_variable_keys'] == ['name'] - assert prompt_template['special_variable_keys'] == [] + assert prompt_template["prompt_template"].template == pre_prompt + "\n" + assert prompt_template["custom_variable_keys"] == ["name"] + assert prompt_template["special_variable_keys"] == [] def test__get_chat_model_prompt_messages(): model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity) - model_config_mock.provider = 'openai' - model_config_mock.model = 'gpt-4' + model_config_mock.provider = "openai" + model_config_mock.model = "gpt-4" memory_mock = MagicMock(spec=TokenBufferMemory) - history_prompt_messages = [ - UserPromptMessage(content="Hi"), - AssistantPromptMessage(content="Hello") - ] + history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")] memory_mock.get_history_prompt_messages.return_value = history_prompt_messages prompt_transform = SimplePromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) pre_prompt = "You are a helpful assistant {{name}}." - inputs = { - "name": "John" - } + inputs = {"name": "John"} context = "yes or no." query = "How are you?" prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages( @@ -167,7 +169,7 @@ def test__get_chat_model_prompt_messages(): files=[], context=context, memory=memory_mock, - model_config=model_config_mock + model_config=model_config_mock, ) prompt_template = prompt_transform.get_prompt_template( @@ -180,8 +182,8 @@ def test__get_chat_model_prompt_messages(): with_memory_prompt=False, ) - full_inputs = {**inputs, '#context#': context} - real_system_prompt = prompt_template['prompt_template'].format(full_inputs) + full_inputs = {**inputs, "#context#": context} + real_system_prompt = prompt_template["prompt_template"].format(full_inputs) assert len(prompt_messages) == 4 assert prompt_messages[0].content == real_system_prompt @@ -192,26 +194,18 @@ def test__get_chat_model_prompt_messages(): def test__get_completion_model_prompt_messages(): model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity) - model_config_mock.provider = 'openai' - model_config_mock.model = 'gpt-3.5-turbo-instruct' + model_config_mock.provider = "openai" + model_config_mock.model = "gpt-3.5-turbo-instruct" - memory = TokenBufferMemory( - conversation=Conversation(), - model_instance=model_config_mock - ) + memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock) - history_prompt_messages = [ - UserPromptMessage(content="Hi"), - AssistantPromptMessage(content="Hello") - ] + history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")] memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages) prompt_transform = SimplePromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) pre_prompt = "You are a helpful assistant {{name}}." - inputs = { - "name": "John" - } + inputs = {"name": "John"} context = "yes or no." query = "How are you?" prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages( @@ -222,7 +216,7 @@ def test__get_completion_model_prompt_messages(): files=[], context=context, memory=memory, - model_config=model_config_mock + model_config=model_config_mock, ) prompt_template = prompt_transform.get_prompt_template( @@ -235,14 +229,19 @@ def test__get_completion_model_prompt_messages(): with_memory_prompt=True, ) - prompt_rules = prompt_template['prompt_rules'] - full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text( - max_token_limit=2000, - human_prefix=prompt_rules.get("human_prefix", "Human"), - ai_prefix=prompt_rules.get("assistant_prefix", "Assistant") - )} - real_prompt = prompt_template['prompt_template'].format(full_inputs) + prompt_rules = prompt_template["prompt_rules"] + full_inputs = { + **inputs, + "#context#": context, + "#query#": query, + "#histories#": memory.get_history_prompt_text( + max_token_limit=2000, + human_prefix=prompt_rules.get("human_prefix", "Human"), + ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"), + ), + } + real_prompt = prompt_template["prompt_template"].format(full_inputs) assert len(prompt_messages) == 1 - assert stops == prompt_rules.get('stops') + assert stops == prompt_rules.get("stops") assert prompt_messages[0].content == real_prompt diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py index 9e43b23658..8d735cae86 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py @@ -5,20 +5,15 @@ from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig def test_default_value(): - valid_config = { - 'host': 'localhost', - 'port': 19530, - 'user': 'root', - 'password': 'Milvus' - } + valid_config = {"host": "localhost", "port": 19530, "user": "root", "password": "Milvus"} for key in valid_config: config = valid_config.copy() del config[key] with pytest.raises(ValidationError) as e: MilvusConfig(**config) - assert e.value.errors()[0]['msg'] == f'Value error, config MILVUS_{key.upper()} is required' + assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required" config = MilvusConfig(**valid_config) assert config.secure is False - assert config.database == 'default' + assert config.database == "default" diff --git a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py index a8bba11e16..d5a1d8f436 100644 --- a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py +++ b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py @@ -9,19 +9,17 @@ from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_resp def test_firecrawl_web_extractor_crawl_mode(mocker): url = "https://firecrawl.dev" - api_key = os.getenv('FIRECRAWL_API_KEY') or 'fc-' - base_url = 'https://api.firecrawl.dev' - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=base_url) + api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-" + base_url = "https://api.firecrawl.dev" + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=base_url) params = { - 'crawlerOptions': { + "crawlerOptions": { "includes": [], "excludes": [], "generateImgAltText": True, "maxDepth": 1, "limit": 1, - 'returnOnlyUrls': False, - + "returnOnlyUrls": False, } } mocked_firecrawl = { diff --git a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py index b231fe479b..eea584a2f8 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py @@ -8,11 +8,8 @@ page_id = "page1" extractor = notion_extractor.NotionExtractor( - notion_workspace_id='x', - notion_obj_id='x', - notion_page_type='page', - tenant_id='x', - notion_access_token='x') + notion_workspace_id="x", notion_obj_id="x", notion_page_type="page", tenant_id="x", notion_access_token="x" +) def _generate_page(page_title: str): @@ -21,16 +18,10 @@ def _generate_page(page_title: str): "id": page_id, "properties": { "Page": { - "type": "title", - "title": [ - { - "type": "text", - "text": {"content": page_title}, - "plain_text": page_title - } - ] + "type": "title", + "title": [{"type": "text", "text": {"content": page_title}, "plain_text": page_title}], } - } + }, } @@ -38,10 +29,7 @@ def _generate_block(block_id: str, block_type: str, block_text: str): return { "object": "block", "id": block_id, - "parent": { - "type": "page_id", - "page_id": page_id - }, + "parent": {"type": "page_id", "page_id": page_id}, "type": block_type, "has_children": False, block_type: { @@ -49,10 +37,11 @@ def _generate_block(block_id: str, block_type: str, block_text: str): { "type": "text", "text": {"content": block_text}, - "plain_text": block_text, - }] - } - } + "plain_text": block_text, + } + ] + }, + } def _mock_response(data): @@ -63,7 +52,7 @@ def _mock_response(data): def _remove_multiple_new_lines(text): - while '\n\n' in text: + while "\n\n" in text: text = text.replace("\n\n", "\n") return text.strip() @@ -71,21 +60,21 @@ def _remove_multiple_new_lines(text): def test_notion_page(mocker): texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"] mocked_notion_page = { - "object": "list", - "results": [ - _generate_block("b1", "heading_1", texts[0]), - _generate_block("b2", "heading_2", texts[1]), - _generate_block("b3", "paragraph", texts[2]), - _generate_block("b4", "heading_3", texts[3]) - ], - "next_cursor": None + "object": "list", + "results": [ + _generate_block("b1", "heading_1", texts[0]), + _generate_block("b2", "heading_2", texts[1]), + _generate_block("b3", "paragraph", texts[2]), + _generate_block("b4", "heading_3", texts[3]), + ], + "next_cursor": None, } mocker.patch("requests.request", return_value=_mock_response(mocked_notion_page)) page_docs = extractor._load_data_as_documents(page_id, "page") assert len(page_docs) == 1 content = _remove_multiple_new_lines(page_docs[0].page_content) - assert content == '# Head 1\n## 1.1\nparagraph 1\n### 1.1.1' + assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1" def test_notion_database(mocker): @@ -93,10 +82,10 @@ def test_notion_database(mocker): mocked_notion_database = { "object": "list", "results": [_generate_page(i) for i in page_title_list], - "next_cursor": None + "next_cursor": None, } mocker.patch("requests.post", return_value=_mock_response(mocked_notion_database)) database_docs = extractor._load_data_as_documents(database_id, "database") assert len(database_docs) == 1 content = _remove_multiple_new_lines(database_docs[0].page_content) - assert content == '\n'.join([f'Page:{i}' for i in page_title_list]) + assert content == "\n".join([f"Page:{i}" for i in page_title_list]) diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py index 3024a54a4d..2808b5b0fa 100644 --- a/api/tests/unit_tests/core/test_model_manager.py +++ b/api/tests/unit_tests/core/test_model_manager.py @@ -10,36 +10,24 @@ from core.model_runtime.entities.model_entities import ModelType @pytest.fixture def lb_model_manager(): load_balancing_configs = [ - ModelLoadBalancingConfiguration( - id='id1', - name='__inherit__', - credentials={} - ), - ModelLoadBalancingConfiguration( - id='id2', - name='first', - credentials={"openai_api_key": "fake_key"} - ), - ModelLoadBalancingConfiguration( - id='id3', - name='second', - credentials={"openai_api_key": "fake_key"} - ) + ModelLoadBalancingConfiguration(id="id1", name="__inherit__", credentials={}), + ModelLoadBalancingConfiguration(id="id2", name="first", credentials={"openai_api_key": "fake_key"}), + ModelLoadBalancingConfiguration(id="id3", name="second", credentials={"openai_api_key": "fake_key"}), ] lb_model_manager = LBModelManager( - tenant_id='tenant_id', - provider='openai', + tenant_id="tenant_id", + provider="openai", model_type=ModelType.LLM, - model='gpt-4', + model="gpt-4", load_balancing_configs=load_balancing_configs, - managed_credentials={"openai_api_key": "fake_key"} + managed_credentials={"openai_api_key": "fake_key"}, ) lb_model_manager.cooldown = MagicMock(return_value=None) def is_cooldown(config: ModelLoadBalancingConfiguration): - if config.id == 'id1': + if config.id == "id1": return True return False @@ -61,14 +49,15 @@ def test_lb_model_manager_fetch_next(mocker, lb_model_manager): assert lb_model_manager.in_cooldown(config3) is False start_index = 0 + def incr(key): nonlocal start_index start_index += 1 return start_index - mocker.patch('redis.Redis.incr', side_effect=incr) - mocker.patch('redis.Redis.set', return_value=None) - mocker.patch('redis.Redis.expire', return_value=None) + mocker.patch("redis.Redis.incr", side_effect=incr) + mocker.patch("redis.Redis.set", return_value=None) + mocker.patch("redis.Redis.expire", return_value=None) config = lb_model_manager.fetch_next() assert config == config2 diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 072b6f100f..2f4214a580 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -11,62 +11,62 @@ def test__to_model_settings(mocker): provider_entity = None for provider in provider_entities: - if provider.provider == 'openai': + if provider.provider == "openai": provider_entity = provider # Mocking the inputs - provider_model_settings = [ProviderModelSetting( - id='id', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - enabled=True, - load_balancing_enabled=True - )] - load_balancing_model_configs = [ - LoadBalancingModelConfig( - id='id1', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - name='__inherit__', - encrypted_config=None, - enabled=True - ), - LoadBalancingModelConfig( - id='id2', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - name='first', - encrypted_config='{"openai_api_key": "fake_key"}', - enabled=True + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=True, ) ] + load_balancing_model_configs = [ + LoadBalancingModelConfig( + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", + encrypted_config=None, + enabled=True, + ), + LoadBalancingModelConfig( + id="id2", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="first", + encrypted_config='{"openai_api_key": "fake_key"}', + enabled=True, + ), + ] - mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"}) + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) provider_manager = ProviderManager() # Running the method - result = provider_manager._to_model_settings( - provider_entity, - provider_model_settings, - load_balancing_model_configs - ) + result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs) # Asserting that the result is as expected assert len(result) == 1 assert isinstance(result[0], ModelSettings) - assert result[0].model == 'gpt-4' + assert result[0].model == "gpt-4" assert result[0].model_type == ModelType.LLM assert result[0].enabled is True assert len(result[0].load_balancing_configs) == 2 - assert result[0].load_balancing_configs[0].name == '__inherit__' - assert result[0].load_balancing_configs[1].name == 'first' + assert result[0].load_balancing_configs[0].name == "__inherit__" + assert result[0].load_balancing_configs[1].name == "first" def test__to_model_settings_only_one_lb(mocker): @@ -75,47 +75,47 @@ def test__to_model_settings_only_one_lb(mocker): provider_entity = None for provider in provider_entities: - if provider.provider == 'openai': + if provider.provider == "openai": provider_entity = provider # Mocking the inputs - provider_model_settings = [ProviderModelSetting( - id='id', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - enabled=True, - load_balancing_enabled=True - )] + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=True, + ) + ] load_balancing_model_configs = [ LoadBalancingModelConfig( - id='id1', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - name='__inherit__', + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", encrypted_config=None, - enabled=True + enabled=True, ) ] - mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"}) + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) provider_manager = ProviderManager() # Running the method - result = provider_manager._to_model_settings( - provider_entity, - provider_model_settings, - load_balancing_model_configs - ) + result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs) # Asserting that the result is as expected assert len(result) == 1 assert isinstance(result[0], ModelSettings) - assert result[0].model == 'gpt-4' + assert result[0].model == "gpt-4" assert result[0].model_type == ModelType.LLM assert result[0].enabled is True assert len(result[0].load_balancing_configs) == 0 @@ -127,57 +127,57 @@ def test__to_model_settings_lb_disabled(mocker): provider_entity = None for provider in provider_entities: - if provider.provider == 'openai': + if provider.provider == "openai": provider_entity = provider # Mocking the inputs - provider_model_settings = [ProviderModelSetting( - id='id', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - enabled=True, - load_balancing_enabled=False - )] - load_balancing_model_configs = [ - LoadBalancingModelConfig( - id='id1', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - name='__inherit__', - encrypted_config=None, - enabled=True - ), - LoadBalancingModelConfig( - id='id2', - tenant_id='tenant_id', - provider_name='openai', - model_name='gpt-4', - model_type='text-generation', - name='first', - encrypted_config='{"openai_api_key": "fake_key"}', - enabled=True + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=False, ) ] + load_balancing_model_configs = [ + LoadBalancingModelConfig( + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", + encrypted_config=None, + enabled=True, + ), + LoadBalancingModelConfig( + id="id2", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="first", + encrypted_config='{"openai_api_key": "fake_key"}', + enabled=True, + ), + ] - mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"}) + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) provider_manager = ProviderManager() # Running the method - result = provider_manager._to_model_settings( - provider_entity, - provider_model_settings, - load_balancing_model_configs - ) + result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs) # Asserting that the result is as expected assert len(result) == 1 assert isinstance(result[0], ModelSettings) - assert result[0].model == 'gpt-4' + assert result[0].model == "gpt-4" assert result[0].model_type == ModelType.LLM assert result[0].enabled is True assert len(result[0].load_balancing_configs) == 0 diff --git a/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py b/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py index 9addeeadca..279a6cdbc3 100644 --- a/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py +++ b/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py @@ -5,52 +5,52 @@ from core.tools.utils.tool_parameter_converter import ToolParameterConverter def test_get_parameter_type(): - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == 'string' - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == 'string' - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == 'boolean' - assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == 'number' + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == "string" + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == "string" + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == "boolean" + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == "number" with pytest.raises(ValueError): - ToolParameterConverter.get_parameter_type('unsupported_type') + ToolParameterConverter.get_parameter_type("unsupported_type") def test_cast_parameter_by_type(): # string - assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.STRING) == 'test' - assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == '1' - assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == '1.0' - assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == '' + assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.STRING) == "test" + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == "1" + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == "1.0" + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == "" # secret input - assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.SECRET_INPUT) == 'test' - assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == '1' - assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == '1.0' - assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == '' + assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SECRET_INPUT) == "test" + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == "1" + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == "1.0" + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == "" # select - assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.SELECT) == 'test' - assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == '1' - assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == '1.0' - assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == '' + assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SELECT) == "test" + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == "1" + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == "1.0" + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == "" # boolean - true_values = [True, 'True', 'true', '1', 'YES', 'Yes', 'yes', 'y', 'something'] + true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"] for value in true_values: assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is True - false_values = [False, 'False', 'false', '0', 'NO', 'No', 'no', 'n', None, ''] + false_values = [False, "False", "false", "0", "NO", "No", "no", "n", None, ""] for value in false_values: assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is False # number - assert ToolParameterConverter.cast_parameter_by_type('1', ToolParameter.ToolParameterType.NUMBER) == 1 - assert ToolParameterConverter.cast_parameter_by_type('1.0', ToolParameter.ToolParameterType.NUMBER) == 1.0 - assert ToolParameterConverter.cast_parameter_by_type('-1.0', ToolParameter.ToolParameterType.NUMBER) == -1.0 + assert ToolParameterConverter.cast_parameter_by_type("1", ToolParameter.ToolParameterType.NUMBER) == 1 + assert ToolParameterConverter.cast_parameter_by_type("1.0", ToolParameter.ToolParameterType.NUMBER) == 1.0 + assert ToolParameterConverter.cast_parameter_by_type("-1.0", ToolParameter.ToolParameterType.NUMBER) == -1.0 assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.NUMBER) == 1 assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.NUMBER) == 1.0 assert ToolParameterConverter.cast_parameter_by_type(-1.0, ToolParameter.ToolParameterType.NUMBER) == -1.0 assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None # unknown - assert ToolParameterConverter.cast_parameter_by_type('1', 'unknown_type') == '1' - assert ToolParameterConverter.cast_parameter_by_type(1, 'unknown_type') == '1' + assert ToolParameterConverter.cast_parameter_by_type("1", "unknown_type") == "1" + assert ToolParameterConverter.cast_parameter_by_type(1, "unknown_type") == "1" assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py index 44b7c85256..8020674ee6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -11,29 +11,30 @@ from models.workflow import WorkflowNodeExecutionStatus def test_execute_answer(): node = AnswerNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, config={ - 'id': 'answer', - 'data': { - 'title': '123', - 'type': 'answer', - 'answer': 'Today\'s weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.' - } - } + "id": "answer", + "data": { + "title": "123", + "type": "answer", + "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + }, + }, ) # construct variable pool - pool = VariablePool(system_variables={ - SystemVariableKey.FILES: [], - SystemVariableKey.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) - pool.add(['start', 'weather'], 'sunny') - pool.add(['llm', 'text'], 'You are a helpful AI.') + pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + ) + pool.add(["start", "weather"], "sunny") + pool.add(["llm", "text"], "You are a helpful AI.") # Mock db.session.close() db.session.close = MagicMock() @@ -42,4 +43,4 @@ def test_execute_answer(): result = node._run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['answer'] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin." + assert result.outputs["answer"] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin." diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 87ebcb34e6..9535bc2186 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -11,134 +11,81 @@ from models.workflow import WorkflowNodeExecutionStatus def test_execute_if_else_result_true(): node = IfElseNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, config={ - 'id': 'if-else', - 'data': { - 'title': '123', - 'type': 'if-else', - 'logical_operator': 'and', - 'conditions': [ + "id": "if-else", + "data": { + "title": "123", + "type": "if-else", + "logical_operator": "and", + "conditions": [ { - 'comparison_operator': 'contains', - 'variable_selector': ['start', 'array_contains'], - 'value': 'ab' + "comparison_operator": "contains", + "variable_selector": ["start", "array_contains"], + "value": "ab", }, { - 'comparison_operator': 'not contains', - 'variable_selector': ['start', 'array_not_contains'], - 'value': 'ab' + "comparison_operator": "not contains", + "variable_selector": ["start", "array_not_contains"], + "value": "ab", }, + {"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"}, { - 'comparison_operator': 'contains', - 'variable_selector': ['start', 'contains'], - 'value': 'ab' + "comparison_operator": "not contains", + "variable_selector": ["start", "not_contains"], + "value": "ab", }, + {"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"}, + {"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"}, + {"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"}, + {"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"}, + {"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"}, + {"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"}, + {"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"}, + {"comparison_operator": "≠", "variable_selector": ["start", "not_equals"], "value": "22"}, + {"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"}, + {"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"}, { - 'comparison_operator': 'not contains', - 'variable_selector': ['start', 'not_contains'], - 'value': 'ab' + "comparison_operator": "≥", + "variable_selector": ["start", "greater_than_or_equal"], + "value": "22", }, - { - 'comparison_operator': 'start with', - 'variable_selector': ['start', 'start_with'], - 'value': 'ab' - }, - { - 'comparison_operator': 'end with', - 'variable_selector': ['start', 'end_with'], - 'value': 'ab' - }, - { - 'comparison_operator': 'is', - 'variable_selector': ['start', 'is'], - 'value': 'ab' - }, - { - 'comparison_operator': 'is not', - 'variable_selector': ['start', 'is_not'], - 'value': 'ab' - }, - { - 'comparison_operator': 'empty', - 'variable_selector': ['start', 'empty'], - 'value': 'ab' - }, - { - 'comparison_operator': 'not empty', - 'variable_selector': ['start', 'not_empty'], - 'value': 'ab' - }, - { - 'comparison_operator': '=', - 'variable_selector': ['start', 'equals'], - 'value': '22' - }, - { - 'comparison_operator': '≠', - 'variable_selector': ['start', 'not_equals'], - 'value': '22' - }, - { - 'comparison_operator': '>', - 'variable_selector': ['start', 'greater_than'], - 'value': '22' - }, - { - 'comparison_operator': '<', - 'variable_selector': ['start', 'less_than'], - 'value': '22' - }, - { - 'comparison_operator': '≥', - 'variable_selector': ['start', 'greater_than_or_equal'], - 'value': '22' - }, - { - 'comparison_operator': '≤', - 'variable_selector': ['start', 'less_than_or_equal'], - 'value': '22' - }, - { - 'comparison_operator': 'null', - 'variable_selector': ['start', 'null'] - }, - { - 'comparison_operator': 'not null', - 'variable_selector': ['start', 'not_null'] - }, - ] - } - } + {"comparison_operator": "≤", "variable_selector": ["start", "less_than_or_equal"], "value": "22"}, + {"comparison_operator": "null", "variable_selector": ["start", "null"]}, + {"comparison_operator": "not null", "variable_selector": ["start", "not_null"]}, + ], + }, + }, ) # construct variable pool - pool = VariablePool(system_variables={ - SystemVariableKey.FILES: [], - SystemVariableKey.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) - pool.add(['start', 'array_contains'], ['ab', 'def']) - pool.add(['start', 'array_not_contains'], ['ac', 'def']) - pool.add(['start', 'contains'], 'cabcde') - pool.add(['start', 'not_contains'], 'zacde') - pool.add(['start', 'start_with'], 'abc') - pool.add(['start', 'end_with'], 'zzab') - pool.add(['start', 'is'], 'ab') - pool.add(['start', 'is_not'], 'aab') - pool.add(['start', 'empty'], '') - pool.add(['start', 'not_empty'], 'aaa') - pool.add(['start', 'equals'], 22) - pool.add(['start', 'not_equals'], 23) - pool.add(['start', 'greater_than'], 23) - pool.add(['start', 'less_than'], 21) - pool.add(['start', 'greater_than_or_equal'], 22) - pool.add(['start', 'less_than_or_equal'], 21) - pool.add(['start', 'not_null'], '1212') + pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + ) + pool.add(["start", "array_contains"], ["ab", "def"]) + pool.add(["start", "array_not_contains"], ["ac", "def"]) + pool.add(["start", "contains"], "cabcde") + pool.add(["start", "not_contains"], "zacde") + pool.add(["start", "start_with"], "abc") + pool.add(["start", "end_with"], "zzab") + pool.add(["start", "is"], "ab") + pool.add(["start", "is_not"], "aab") + pool.add(["start", "empty"], "") + pool.add(["start", "not_empty"], "aaa") + pool.add(["start", "equals"], 22) + pool.add(["start", "not_equals"], 23) + pool.add(["start", "greater_than"], 23) + pool.add(["start", "less_than"], 21) + pool.add(["start", "greater_than_or_equal"], 22) + pool.add(["start", "less_than_or_equal"], 21) + pool.add(["start", "not_null"], "1212") # Mock db.session.close() db.session.close = MagicMock() @@ -147,46 +94,47 @@ def test_execute_if_else_result_true(): result = node._run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['result'] is True + assert result.outputs["result"] is True def test_execute_if_else_result_false(): node = IfElseNode( - tenant_id='1', - app_id='1', - workflow_id='1', - user_id='1', + tenant_id="1", + app_id="1", + workflow_id="1", + user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, config={ - 'id': 'if-else', - 'data': { - 'title': '123', - 'type': 'if-else', - 'logical_operator': 'or', - 'conditions': [ + "id": "if-else", + "data": { + "title": "123", + "type": "if-else", + "logical_operator": "or", + "conditions": [ { - 'comparison_operator': 'contains', - 'variable_selector': ['start', 'array_contains'], - 'value': 'ab' + "comparison_operator": "contains", + "variable_selector": ["start", "array_contains"], + "value": "ab", }, { - 'comparison_operator': 'not contains', - 'variable_selector': ['start', 'array_not_contains'], - 'value': 'ab' - } - ] - } - } + "comparison_operator": "not contains", + "variable_selector": ["start", "array_not_contains"], + "value": "ab", + }, + ], + }, + }, ) # construct variable pool - pool = VariablePool(system_variables={ - SystemVariableKey.FILES: [], - SystemVariableKey.USER_ID: 'aaa' - }, user_inputs={}, environment_variables=[]) - pool.add(['start', 'array_contains'], ['1ab', 'def']) - pool.add(['start', 'array_not_contains'], ['ab', 'def']) + pool = VariablePool( + system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + user_inputs={}, + environment_variables=[], + ) + pool.add(["start", "array_contains"], ["1ab", "def"]) + pool.add(["start", "array_not_contains"], ["ab", "def"]) # Mock db.session.close() db.session.close = MagicMock() @@ -195,4 +143,4 @@ def test_execute_if_else_result_false(): result = node._run(pool) assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs['result'] is False + assert result.outputs["result"] is False diff --git a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py index 5df8c1b763..e26c7df642 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py @@ -8,41 +8,41 @@ from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode -DEFAULT_NODE_ID = 'node_id' +DEFAULT_NODE_ID = "node_id" def test_overwrite_string_variable(): conversation_variable = StringVariable( id=str(uuid4()), - name='test_conversation_variable', - value='the first value', + name="test_conversation_variable", + value="the first value", ) input_variable = StringVariable( id=str(uuid4()), - name='test_string_variable', - value='the second value', + name="test_string_variable", + value="the second value", ) node = VariableAssignerNode( - tenant_id='tenant_id', - app_id='app_id', - workflow_id='workflow_id', - user_id='user_id', + tenant_id="tenant_id", + app_id="app_id", + workflow_id="workflow_id", + user_id="user_id", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, config={ - 'id': 'node_id', - 'data': { - 'assigned_variable_selector': ['conversation', conversation_variable.name], - 'write_mode': WriteMode.OVER_WRITE.value, - 'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name], + "id": "node_id", + "data": { + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.OVER_WRITE.value, + "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], }, }, ) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'}, + system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -52,48 +52,48 @@ def test_overwrite_string_variable(): input_variable, ) - with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run: + with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run: node.run(variable_pool) mock_run.assert_called_once() - got = variable_pool.get(['conversation', conversation_variable.name]) + got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None - assert got.value == 'the second value' - assert got.to_object() == 'the second value' + assert got.value == "the second value" + assert got.to_object() == "the second value" def test_append_variable_to_array(): conversation_variable = ArrayStringVariable( id=str(uuid4()), - name='test_conversation_variable', - value=['the first value'], + name="test_conversation_variable", + value=["the first value"], ) input_variable = StringVariable( id=str(uuid4()), - name='test_string_variable', - value='the second value', + name="test_string_variable", + value="the second value", ) node = VariableAssignerNode( - tenant_id='tenant_id', - app_id='app_id', - workflow_id='workflow_id', - user_id='user_id', + tenant_id="tenant_id", + app_id="app_id", + workflow_id="workflow_id", + user_id="user_id", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, config={ - 'id': 'node_id', - 'data': { - 'assigned_variable_selector': ['conversation', conversation_variable.name], - 'write_mode': WriteMode.APPEND.value, - 'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name], + "id": "node_id", + "data": { + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.APPEND.value, + "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], }, }, ) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'}, + system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -103,41 +103,41 @@ def test_append_variable_to_array(): input_variable, ) - with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run: + with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run: node.run(variable_pool) mock_run.assert_called_once() - got = variable_pool.get(['conversation', conversation_variable.name]) + got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None - assert got.to_object() == ['the first value', 'the second value'] + assert got.to_object() == ["the first value", "the second value"] def test_clear_array(): conversation_variable = ArrayStringVariable( id=str(uuid4()), - name='test_conversation_variable', - value=['the first value'], + name="test_conversation_variable", + value=["the first value"], ) node = VariableAssignerNode( - tenant_id='tenant_id', - app_id='app_id', - workflow_id='workflow_id', - user_id='user_id', + tenant_id="tenant_id", + app_id="app_id", + workflow_id="workflow_id", + user_id="user_id", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, config={ - 'id': 'node_id', - 'data': { - 'assigned_variable_selector': ['conversation', conversation_variable.name], - 'write_mode': WriteMode.CLEAR.value, - 'input_variable_selector': [], + "id": "node_id", + "data": { + "assigned_variable_selector": ["conversation", conversation_variable.name], + "write_mode": WriteMode.CLEAR.value, + "input_variable_selector": [], }, }, ) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'}, + system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -145,6 +145,6 @@ def test_clear_array(): node.run(variable_pool) - got = variable_pool.get(['conversation', conversation_variable.name]) + got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None assert got.to_object() == [] diff --git a/api/tests/unit_tests/libs/test_pandas.py b/api/tests/unit_tests/libs/test_pandas.py index bbc372ed61..21c2f0781d 100644 --- a/api/tests/unit_tests/libs/test_pandas.py +++ b/api/tests/unit_tests/libs/test_pandas.py @@ -3,50 +3,46 @@ import pandas as pd def test_pandas_csv(tmp_path, monkeypatch): monkeypatch.chdir(tmp_path) - data = {'col1': [1, 2.2, -3.3, 4.0, 5], - 'col2': ['A', 'B', 'C', 'D', 'E']} + data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]} df1 = pd.DataFrame(data) # write to csv file - csv_file_path = tmp_path.joinpath('example.csv') + csv_file_path = tmp_path.joinpath("example.csv") df1.to_csv(csv_file_path, index=False) # read from csv file - df2 = pd.read_csv(csv_file_path, on_bad_lines='skip') - assert df2[df2.columns[0]].to_list() == data['col1'] - assert df2[df2.columns[1]].to_list() == data['col2'] + df2 = pd.read_csv(csv_file_path, on_bad_lines="skip") + assert df2[df2.columns[0]].to_list() == data["col1"] + assert df2[df2.columns[1]].to_list() == data["col2"] def test_pandas_xlsx(tmp_path, monkeypatch): monkeypatch.chdir(tmp_path) - data = {'col1': [1, 2.2, -3.3, 4.0, 5], - 'col2': ['A', 'B', 'C', 'D', 'E']} + data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]} df1 = pd.DataFrame(data) # write to xlsx file - xlsx_file_path = tmp_path.joinpath('example.xlsx') + xlsx_file_path = tmp_path.joinpath("example.xlsx") df1.to_excel(xlsx_file_path, index=False) # read from xlsx file df2 = pd.read_excel(xlsx_file_path) - assert df2[df2.columns[0]].to_list() == data['col1'] - assert df2[df2.columns[1]].to_list() == data['col2'] + assert df2[df2.columns[0]].to_list() == data["col1"] + assert df2[df2.columns[1]].to_list() == data["col2"] def test_pandas_xlsx_with_sheets(tmp_path, monkeypatch): monkeypatch.chdir(tmp_path) - data1 = {'col1': [1, 2, 3, 4, 5], - 'col2': ['A', 'B', 'C', 'D', 'E']} + data1 = {"col1": [1, 2, 3, 4, 5], "col2": ["A", "B", "C", "D", "E"]} df1 = pd.DataFrame(data1) - data2 = {'col1': [6, 7, 8, 9, 10], - 'col2': ['F', 'G', 'H', 'I', 'J']} + data2 = {"col1": [6, 7, 8, 9, 10], "col2": ["F", "G", "H", "I", "J"]} df2 = pd.DataFrame(data2) # write to xlsx file with sheets - xlsx_file_path = tmp_path.joinpath('example_with_sheets.xlsx') - sheet1 = 'Sheet1' - sheet2 = 'Sheet2' + xlsx_file_path = tmp_path.joinpath("example_with_sheets.xlsx") + sheet1 = "Sheet1" + sheet2 = "Sheet2" with pd.ExcelWriter(xlsx_file_path) as excel_writer: df1.to_excel(excel_writer, sheet_name=sheet1, index=False) df2.to_excel(excel_writer, sheet_name=sheet2, index=False) @@ -54,9 +50,9 @@ def test_pandas_xlsx_with_sheets(tmp_path, monkeypatch): # read from xlsx file with sheets with pd.ExcelFile(xlsx_file_path) as excel_file: df1 = pd.read_excel(excel_file, sheet_name=sheet1) - assert df1[df1.columns[0]].to_list() == data1['col1'] - assert df1[df1.columns[1]].to_list() == data1['col2'] + assert df1[df1.columns[0]].to_list() == data1["col1"] + assert df1[df1.columns[1]].to_list() == data1["col2"] df2 = pd.read_excel(excel_file, sheet_name=sheet2) - assert df2[df2.columns[0]].to_list() == data2['col1'] - assert df2[df2.columns[1]].to_list() == data2['col2'] + assert df2[df2.columns[0]].to_list() == data2["col1"] + assert df2[df2.columns[1]].to_list() == data2["col2"] diff --git a/api/tests/unit_tests/libs/test_rsa.py b/api/tests/unit_tests/libs/test_rsa.py index a979b77d70..2dc51252f0 100644 --- a/api/tests/unit_tests/libs/test_rsa.py +++ b/api/tests/unit_tests/libs/test_rsa.py @@ -15,7 +15,7 @@ def test_gmpy2_pkcs10aep_cipher() -> None: private_rsa_key = RSA.import_key(private_key) private_cipher_rsa = gmpy2_pkcs10aep_cipher.new(private_rsa_key) - raw_text = 'raw_text' + raw_text = "raw_text" raw_text_bytes = raw_text.encode() # RSA encryption by public key and decryption by private key diff --git a/api/tests/unit_tests/libs/test_yarl.py b/api/tests/unit_tests/libs/test_yarl.py index 75a5344126..b9aee4af5f 100644 --- a/api/tests/unit_tests/libs/test_yarl.py +++ b/api/tests/unit_tests/libs/test_yarl.py @@ -3,21 +3,21 @@ from yarl import URL def test_yarl_urls(): - expected_1 = 'https://dify.ai/api' - assert str(URL('https://dify.ai') / 'api') == expected_1 - assert str(URL('https://dify.ai/') / 'api') == expected_1 + expected_1 = "https://dify.ai/api" + assert str(URL("https://dify.ai") / "api") == expected_1 + assert str(URL("https://dify.ai/") / "api") == expected_1 - expected_2 = 'http://dify.ai:12345/api' - assert str(URL('http://dify.ai:12345') / 'api') == expected_2 - assert str(URL('http://dify.ai:12345/') / 'api') == expected_2 + expected_2 = "http://dify.ai:12345/api" + assert str(URL("http://dify.ai:12345") / "api") == expected_2 + assert str(URL("http://dify.ai:12345/") / "api") == expected_2 - expected_3 = 'https://dify.ai/api/v1' - assert str(URL('https://dify.ai') / 'api' / 'v1') == expected_3 - assert str(URL('https://dify.ai') / 'api/v1') == expected_3 - assert str(URL('https://dify.ai/') / 'api/v1') == expected_3 - assert str(URL('https://dify.ai/api') / 'v1') == expected_3 - assert str(URL('https://dify.ai/api/') / 'v1') == expected_3 + expected_3 = "https://dify.ai/api/v1" + assert str(URL("https://dify.ai") / "api" / "v1") == expected_3 + assert str(URL("https://dify.ai") / "api/v1") == expected_3 + assert str(URL("https://dify.ai/") / "api/v1") == expected_3 + assert str(URL("https://dify.ai/api") / "v1") == expected_3 + assert str(URL("https://dify.ai/api/") / "v1") == expected_3 with pytest.raises(ValueError) as e1: - str(URL('https://dify.ai') / '/api') + str(URL("https://dify.ai") / "/api") assert str(e1.value) == "Appending path '/api' starting from slash is forbidden" diff --git a/api/tests/unit_tests/models/test_account.py b/api/tests/unit_tests/models/test_account.py index 006b99fb7d..026912ffbe 100644 --- a/api/tests/unit_tests/models/test_account.py +++ b/api/tests/unit_tests/models/test_account.py @@ -2,13 +2,13 @@ from models.account import TenantAccountRole def test_account_is_privileged_role() -> None: - assert TenantAccountRole.ADMIN == 'admin' - assert TenantAccountRole.OWNER == 'owner' - assert TenantAccountRole.EDITOR == 'editor' - assert TenantAccountRole.NORMAL == 'normal' + assert TenantAccountRole.ADMIN == "admin" + assert TenantAccountRole.OWNER == "owner" + assert TenantAccountRole.EDITOR == "editor" + assert TenantAccountRole.NORMAL == "normal" assert TenantAccountRole.is_privileged_role(TenantAccountRole.ADMIN) assert TenantAccountRole.is_privileged_role(TenantAccountRole.OWNER) assert not TenantAccountRole.is_privileged_role(TenantAccountRole.NORMAL) assert not TenantAccountRole.is_privileged_role(TenantAccountRole.EDITOR) - assert not TenantAccountRole.is_privileged_role('') + assert not TenantAccountRole.is_privileged_role("") diff --git a/api/tests/unit_tests/models/test_conversation_variable.py b/api/tests/unit_tests/models/test_conversation_variable.py index 9e16010d7e..7968347dec 100644 --- a/api/tests/unit_tests/models/test_conversation_variable.py +++ b/api/tests/unit_tests/models/test_conversation_variable.py @@ -7,19 +7,19 @@ from models import ConversationVariable def test_from_variable_and_to_variable(): variable = factory.build_variable_from_mapping( { - 'id': str(uuid4()), - 'name': 'name', - 'value_type': SegmentType.OBJECT, - 'value': { - 'key': { - 'key': 'value', + "id": str(uuid4()), + "name": "name", + "value_type": SegmentType.OBJECT, + "value": { + "key": { + "key": "value", } }, } ) conversation_variable = ConversationVariable.from_variable( - app_id='app_id', conversation_id='conversation_id', variable=variable + app_id="app_id", conversation_id="conversation_id", variable=variable ) assert conversation_variable.to_variable() == variable diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index bea896b83a..40483d7e3a 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -8,30 +8,30 @@ from models.workflow import Workflow def test_environment_variables(): - contexts.tenant_id.set('tenant_id') + contexts.tenant_id.set("tenant_id") # Create a Workflow instance workflow = Workflow( - tenant_id='tenant_id', - app_id='app_id', - type='workflow', - version='draft', - graph='{}', - features='{}', - created_by='account_id', + tenant_id="tenant_id", + app_id="app_id", + type="workflow", + version="draft", + graph="{}", + features="{}", + created_by="account_id", environment_variables=[], conversation_variables=[], ) # Create some EnvironmentVariable instances - variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())}) - variable2 = IntegerVariable.model_validate({'name': 'var2', 'value': 123, 'id': str(uuid4())}) - variable3 = SecretVariable.model_validate({'name': 'var3', 'value': 'secret', 'id': str(uuid4())}) - variable4 = FloatVariable.model_validate({'name': 'var4', 'value': 3.14, 'id': str(uuid4())}) + variable1 = StringVariable.model_validate({"name": "var1", "value": "value1", "id": str(uuid4())}) + variable2 = IntegerVariable.model_validate({"name": "var2", "value": 123, "id": str(uuid4())}) + variable3 = SecretVariable.model_validate({"name": "var3", "value": "secret", "id": str(uuid4())}) + variable4 = FloatVariable.model_validate({"name": "var4", "value": 3.14, "id": str(uuid4())}) with ( - mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'), - mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'), + mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), + mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), ): # Set the environment_variables property of the Workflow instance variables = [variable1, variable2, variable3, variable4] @@ -42,30 +42,30 @@ def test_environment_variables(): def test_update_environment_variables(): - contexts.tenant_id.set('tenant_id') + contexts.tenant_id.set("tenant_id") # Create a Workflow instance workflow = Workflow( - tenant_id='tenant_id', - app_id='app_id', - type='workflow', - version='draft', - graph='{}', - features='{}', - created_by='account_id', + tenant_id="tenant_id", + app_id="app_id", + type="workflow", + version="draft", + graph="{}", + features="{}", + created_by="account_id", environment_variables=[], conversation_variables=[], ) # Create some EnvironmentVariable instances - variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())}) - variable2 = IntegerVariable.model_validate({'name': 'var2', 'value': 123, 'id': str(uuid4())}) - variable3 = SecretVariable.model_validate({'name': 'var3', 'value': 'secret', 'id': str(uuid4())}) - variable4 = FloatVariable.model_validate({'name': 'var4', 'value': 3.14, 'id': str(uuid4())}) + variable1 = StringVariable.model_validate({"name": "var1", "value": "value1", "id": str(uuid4())}) + variable2 = IntegerVariable.model_validate({"name": "var2", "value": 123, "id": str(uuid4())}) + variable3 = SecretVariable.model_validate({"name": "var3", "value": "secret", "id": str(uuid4())}) + variable4 = FloatVariable.model_validate({"name": "var4", "value": 3.14, "id": str(uuid4())}) with ( - mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'), - mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'), + mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), + mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), ): variables = [variable1, variable2, variable3, variable4] @@ -76,28 +76,28 @@ def test_update_environment_variables(): # Update the name of variable3 and keep the value as it is variables[2] = variable3.model_copy( update={ - 'name': 'new name', - 'value': HIDDEN_VALUE, + "name": "new name", + "value": HIDDEN_VALUE, } ) workflow.environment_variables = variables - assert workflow.environment_variables[2].name == 'new name' + assert workflow.environment_variables[2].name == "new name" assert workflow.environment_variables[2].value == variable3.value def test_to_dict(): - contexts.tenant_id.set('tenant_id') + contexts.tenant_id.set("tenant_id") # Create a Workflow instance workflow = Workflow( - tenant_id='tenant_id', - app_id='app_id', - type='workflow', - version='draft', - graph='{}', - features='{}', - created_by='account_id', + tenant_id="tenant_id", + app_id="app_id", + type="workflow", + version="draft", + graph="{}", + features="{}", + created_by="account_id", environment_variables=[], conversation_variables=[], ) @@ -105,19 +105,19 @@ def test_to_dict(): # Create some EnvironmentVariable instances with ( - mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'), - mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'), + mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), + mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), ): # Set the environment_variables property of the Workflow instance workflow.environment_variables = [ - SecretVariable.model_validate({'name': 'secret', 'value': 'secret', 'id': str(uuid4())}), - StringVariable.model_validate({'name': 'text', 'value': 'text', 'id': str(uuid4())}), + SecretVariable.model_validate({"name": "secret", "value": "secret", "id": str(uuid4())}), + StringVariable.model_validate({"name": "text", "value": "text", "id": str(uuid4())}), ] workflow_dict = workflow.to_dict() - assert workflow_dict['environment_variables'][0]['value'] == '' - assert workflow_dict['environment_variables'][1]['value'] == 'text' + assert workflow_dict["environment_variables"][0]["value"] == "" + assert workflow_dict["environment_variables"][1]["value"] == "text" workflow_dict = workflow.to_dict(include_secret=True) - assert workflow_dict['environment_variables'][0]['value'] == 'secret' - assert workflow_dict['environment_variables'][1]['value'] == 'text' + assert workflow_dict["environment_variables"][0]["value"] == "secret" + assert workflow_dict["environment_variables"][1]["value"] == "text" diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index a45423bf39..805d92dfc9 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -83,18 +83,12 @@ def test__convert_to_http_request_node_for_chatbot(default_variables): external_data_variables = [ ExternalDataVariableEntity( - variable="external_variable", - type="api", - config={ - "api_based_extension_id": api_based_extension_id - } + variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id} ) ] nodes, _ = workflow_converter._convert_to_http_request_node( - app_model=app_model, - variables=default_variables, - external_data_variables=external_data_variables + app_model=app_model, variables=default_variables, external_data_variables=external_data_variables ) assert len(nodes) == 2 @@ -105,10 +99,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables): assert http_request_node["data"]["method"] == "post" assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint assert http_request_node["data"]["authorization"]["type"] == "api-key" - assert http_request_node["data"]["authorization"]["config"] == { - "type": "bearer", - "api_key": "api_key" - } + assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"} assert http_request_node["data"]["body"]["type"] == "json" body_data = http_request_node["data"]["body"]["data"] @@ -153,18 +144,12 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables): external_data_variables = [ ExternalDataVariableEntity( - variable="external_variable", - type="api", - config={ - "api_based_extension_id": api_based_extension_id - } + variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id} ) ] nodes, _ = workflow_converter._convert_to_http_request_node( - app_model=app_model, - variables=default_variables, - external_data_variables=external_data_variables + app_model=app_model, variables=default_variables, external_data_variables=external_data_variables ) assert len(nodes) == 2 @@ -175,10 +160,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables): assert http_request_node["data"]["method"] == "post" assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint assert http_request_node["data"]["authorization"]["type"] == "api-key" - assert http_request_node["data"]["authorization"]["config"] == { - "type": "bearer", - "api_key": "api_key" - } + assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"} assert http_request_node["data"]["body"]["type"] == "json" body_data = http_request_node["data"]["body"]["data"] @@ -207,37 +189,25 @@ def test__convert_to_knowledge_retrieval_node_for_chatbot(): retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, top_k=5, score_threshold=0.8, - reranking_model={ - 'reranking_provider_name': 'cohere', - 'reranking_model_name': 'rerank-english-v2.0' - }, - reranking_enabled=True - ) + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"}, + reranking_enabled=True, + ), ) - model_config = ModelConfigEntity( - provider='openai', - model='gpt-4', - mode='chat', - parameters={}, - stop=[] - ) + model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[]) node = WorkflowConverter()._convert_to_knowledge_retrieval_node( - new_app_mode=new_app_mode, - dataset_config=dataset_config, - model_config=model_config + new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config ) assert node["data"]["type"] == "knowledge-retrieval" assert node["data"]["query_variable_selector"] == ["sys", "query"] assert node["data"]["dataset_ids"] == dataset_config.dataset_ids - assert (node["data"]["retrieval_mode"] - == dataset_config.retrieve_config.retrieve_strategy.value) + assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value assert node["data"]["multiple_retrieval_config"] == { "top_k": dataset_config.retrieve_config.top_k, "score_threshold": dataset_config.retrieve_config.score_threshold, - "reranking_model": dataset_config.retrieve_config.reranking_model + "reranking_model": dataset_config.retrieve_config.reranking_model, } @@ -251,37 +221,25 @@ def test__convert_to_knowledge_retrieval_node_for_workflow_app(): retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, top_k=5, score_threshold=0.8, - reranking_model={ - 'reranking_provider_name': 'cohere', - 'reranking_model_name': 'rerank-english-v2.0' - }, - reranking_enabled=True - ) + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"}, + reranking_enabled=True, + ), ) - model_config = ModelConfigEntity( - provider='openai', - model='gpt-4', - mode='chat', - parameters={}, - stop=[] - ) + model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[]) node = WorkflowConverter()._convert_to_knowledge_retrieval_node( - new_app_mode=new_app_mode, - dataset_config=dataset_config, - model_config=model_config + new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config ) assert node["data"]["type"] == "knowledge-retrieval" assert node["data"]["query_variable_selector"] == ["start", dataset_config.retrieve_config.query_variable] assert node["data"]["dataset_ids"] == dataset_config.dataset_ids - assert (node["data"]["retrieval_mode"] - == dataset_config.retrieve_config.retrieve_strategy.value) + assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value assert node["data"]["multiple_retrieval_config"] == { "top_k": dataset_config.retrieve_config.top_k, "score_threshold": dataset_config.retrieve_config.score_threshold, - "reranking_model": dataset_config.retrieve_config.reranking_model + "reranking_model": dataset_config.retrieve_config.reranking_model, } @@ -293,14 +251,12 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables): workflow_converter = WorkflowConverter() start_node = workflow_converter._convert_to_start_node(default_variables) graph = { - "nodes": [ - start_node - ], - "edges": [] # no need + "nodes": [start_node], + "edges": [], # no need } model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' + model_config_mock.provider = "openai" model_config_mock.model = model model_config_mock.mode = model_mode.value model_config_mock.parameters = {} @@ -308,7 +264,7 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables): prompt_template = PromptTemplateEntity( prompt_type=PromptTemplateEntity.PromptType.SIMPLE, - simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}." + simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.", ) llm_node = workflow_converter._convert_to_llm_node( @@ -316,17 +272,17 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables): new_app_mode=new_app_mode, model_config=model_config_mock, graph=graph, - prompt_template=prompt_template + prompt_template=prompt_template, ) assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]['name'] == model - assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["model"]["name"] == model + assert llm_node["data"]["model"]["mode"] == model_mode.value template = prompt_template.simple_prompt_template for v in default_variables: - template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}') - assert llm_node["data"]["prompt_template"][0]['text'] == template + '\n' - assert llm_node["data"]['context']['enabled'] is False + template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") + assert llm_node["data"]["prompt_template"][0]["text"] == template + "\n" + assert llm_node["data"]["context"]["enabled"] is False def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variables): @@ -337,14 +293,12 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variab workflow_converter = WorkflowConverter() start_node = workflow_converter._convert_to_start_node(default_variables) graph = { - "nodes": [ - start_node - ], - "edges": [] # no need + "nodes": [start_node], + "edges": [], # no need } model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' + model_config_mock.provider = "openai" model_config_mock.model = model model_config_mock.mode = model_mode.value model_config_mock.parameters = {} @@ -352,7 +306,7 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variab prompt_template = PromptTemplateEntity( prompt_type=PromptTemplateEntity.PromptType.SIMPLE, - simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}." + simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.", ) llm_node = workflow_converter._convert_to_llm_node( @@ -360,17 +314,17 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variab new_app_mode=new_app_mode, model_config=model_config_mock, graph=graph, - prompt_template=prompt_template + prompt_template=prompt_template, ) assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]['name'] == model - assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["model"]["name"] == model + assert llm_node["data"]["model"]["mode"] == model_mode.value template = prompt_template.simple_prompt_template for v in default_variables: - template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}') - assert llm_node["data"]["prompt_template"]['text'] == template + '\n' - assert llm_node["data"]['context']['enabled'] is False + template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") + assert llm_node["data"]["prompt_template"]["text"] == template + "\n" + assert llm_node["data"]["context"]["enabled"] is False def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables): @@ -381,14 +335,12 @@ def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables) workflow_converter = WorkflowConverter() start_node = workflow_converter._convert_to_start_node(default_variables) graph = { - "nodes": [ - start_node - ], - "edges": [] # no need + "nodes": [start_node], + "edges": [], # no need } model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' + model_config_mock.provider = "openai" model_config_mock.model = model model_config_mock.mode = model_mode.value model_config_mock.parameters = {} @@ -396,12 +348,16 @@ def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables) prompt_template = PromptTemplateEntity( prompt_type=PromptTemplateEntity.PromptType.ADVANCED, - advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(messages=[ - AdvancedChatMessageEntity(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", - role=PromptMessageRole.SYSTEM), - AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), - AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), - ]) + advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity( + messages=[ + AdvancedChatMessageEntity( + text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", + role=PromptMessageRole.SYSTEM, + ), + AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), + AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), + ] + ), ) llm_node = workflow_converter._convert_to_llm_node( @@ -409,18 +365,18 @@ def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables) new_app_mode=new_app_mode, model_config=model_config_mock, graph=graph, - prompt_template=prompt_template + prompt_template=prompt_template, ) assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]['name'] == model - assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["model"]["name"] == model + assert llm_node["data"]["model"]["mode"] == model_mode.value assert isinstance(llm_node["data"]["prompt_template"], list) assert len(llm_node["data"]["prompt_template"]) == len(prompt_template.advanced_chat_prompt_template.messages) template = prompt_template.advanced_chat_prompt_template.messages[0].text for v in default_variables: - template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}') - assert llm_node["data"]["prompt_template"][0]['text'] == template + template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") + assert llm_node["data"]["prompt_template"][0]["text"] == template def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables): @@ -431,14 +387,12 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_var workflow_converter = WorkflowConverter() start_node = workflow_converter._convert_to_start_node(default_variables) graph = { - "nodes": [ - start_node - ], - "edges": [] # no need + "nodes": [start_node], + "edges": [], # no need } model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = 'openai' + model_config_mock.provider = "openai" model_config_mock.model = model model_config_mock.mode = model_mode.value model_config_mock.parameters = {} @@ -448,12 +402,9 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_var prompt_type=PromptTemplateEntity.PromptType.ADVANCED, advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( prompt="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}\n\n" - "Human: hi\nAssistant: ", - role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity( - user="Human", - assistant="Assistant" - ) - ) + "Human: hi\nAssistant: ", + role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="Human", assistant="Assistant"), + ), ) llm_node = workflow_converter._convert_to_llm_node( @@ -461,14 +412,14 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_var new_app_mode=new_app_mode, model_config=model_config_mock, graph=graph, - prompt_template=prompt_template + prompt_template=prompt_template, ) assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]['name'] == model - assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["model"]["name"] == model + assert llm_node["data"]["model"]["mode"] == model_mode.value assert isinstance(llm_node["data"]["prompt_template"], dict) template = prompt_template.advanced_completion_prompt_template.prompt for v in default_variables: - template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}') - assert llm_node["data"]["prompt_template"]['text'] == template + template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") + assert llm_node["data"]["prompt_template"]["text"] == template diff --git a/api/tests/unit_tests/utils/position_helper/test_position_helper.py b/api/tests/unit_tests/utils/position_helper/test_position_helper.py index 1235e559c9..29558a93c2 100644 --- a/api/tests/unit_tests/utils/position_helper/test_position_helper.py +++ b/api/tests/unit_tests/utils/position_helper/test_position_helper.py @@ -8,8 +8,9 @@ from core.helper.position_helper import get_position_map, is_filtered, pin_posit @pytest.fixture def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str: monkeypatch.chdir(tmp_path) - tmp_path.joinpath("example_positions.yaml").write_text(dedent( - """\ + tmp_path.joinpath("example_positions.yaml").write_text( + dedent( + """\ - first - second # - commented @@ -17,57 +18,54 @@ def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str: - 9999999999999 - forth - """)) + """ + ) + ) return str(tmp_path) @pytest.fixture def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str: monkeypatch.chdir(tmp_path) - tmp_path.joinpath("example_positions_all_commented.yaml").write_text(dedent( - """\ + tmp_path.joinpath("example_positions_all_commented.yaml").write_text( + dedent( + """\ # - commented1 # - commented2 - - - """)) + """ + ) + ) return str(tmp_path) def test_position_helper(prepare_example_positions_yaml): - position_map = get_position_map( - folder_path=prepare_example_positions_yaml, - file_name='example_positions.yaml') + position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml") assert len(position_map) == 4 assert position_map == { - 'first': 0, - 'second': 1, - 'third': 2, - 'forth': 3, + "first": 0, + "second": 1, + "third": 2, + "forth": 3, } def test_position_helper_with_all_commented(prepare_empty_commented_positions_yaml): position_map = get_position_map( - folder_path=prepare_empty_commented_positions_yaml, - file_name='example_positions_all_commented.yaml') + folder_path=prepare_empty_commented_positions_yaml, file_name="example_positions_all_commented.yaml" + ) assert position_map == {} def test_excluded_position_data(prepare_example_positions_yaml): - position_map = get_position_map( - folder_path=prepare_example_positions_yaml, - file_name='example_positions.yaml' - ) - pin_list = ['forth', 'first'] + position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml") + pin_list = ["forth", "first"] include_set = set() - exclude_set = {'9999999999999'} + exclude_set = {"9999999999999"} - position_map = pin_position_map( - original_position_map=position_map, - pin_list=pin_list - ) + position_map = pin_position_map(original_position_map=position_map, pin_list=pin_list) data = [ "forth", @@ -90,22 +88,16 @@ def test_excluded_position_data(prepare_example_positions_yaml): ) # assert the result in the correct order - assert sorted_data == ['forth', 'first', 'second', 'third', 'extra1', 'extra2'] + assert sorted_data == ["forth", "first", "second", "third", "extra1", "extra2"] def test_included_position_data(prepare_example_positions_yaml): - position_map = get_position_map( - folder_path=prepare_example_positions_yaml, - file_name='example_positions.yaml' - ) - pin_list = ['forth', 'first'] - include_set = {'forth', 'first'} + position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml") + pin_list = ["forth", "first"] + include_set = {"forth", "first"} exclude_set = {} - position_map = pin_position_map( - original_position_map=position_map, - pin_list=pin_list - ) + position_map = pin_position_map(original_position_map=position_map, pin_list=pin_list) data = [ "forth", @@ -128,4 +120,4 @@ def test_included_position_data(prepare_example_positions_yaml): ) # assert the result in the correct order - assert sorted_data == ['forth', 'first'] + assert sorted_data == ["forth", "first"] diff --git a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py index c0452b4e4d..95b93651d5 100644 --- a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py +++ b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py @@ -5,17 +5,18 @@ from yaml import YAMLError from core.tools.utils.yaml_utils import load_yaml_file -EXAMPLE_YAML_FILE = 'example_yaml.yaml' -INVALID_YAML_FILE = 'invalid_yaml.yaml' -NON_EXISTING_YAML_FILE = 'non_existing_file.yaml' +EXAMPLE_YAML_FILE = "example_yaml.yaml" +INVALID_YAML_FILE = "invalid_yaml.yaml" +NON_EXISTING_YAML_FILE = "non_existing_file.yaml" @pytest.fixture def prepare_example_yaml_file(tmp_path, monkeypatch) -> str: monkeypatch.chdir(tmp_path) file_path = tmp_path.joinpath(EXAMPLE_YAML_FILE) - file_path.write_text(dedent( - """\ + file_path.write_text( + dedent( + """\ address: city: Example City country: Example Country @@ -26,7 +27,9 @@ def prepare_example_yaml_file(tmp_path, monkeypatch) -> str: - Java - C++ empty_key: - """)) + """ + ) + ) return str(file_path) @@ -34,8 +37,9 @@ def prepare_example_yaml_file(tmp_path, monkeypatch) -> str: def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str: monkeypatch.chdir(tmp_path) file_path = tmp_path.joinpath(INVALID_YAML_FILE) - file_path.write_text(dedent( - """\ + file_path.write_text( + dedent( + """\ address: city: Example City country: Example Country @@ -45,13 +49,15 @@ def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str: - Python - Java - C++ - """)) + """ + ) + ) return str(file_path) def test_load_yaml_non_existing_file(): assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {} - assert load_yaml_file(file_path='') == {} + assert load_yaml_file(file_path="") == {} with pytest.raises(FileNotFoundError): load_yaml_file(file_path=NON_EXISTING_YAML_FILE, ignore_error=False) @@ -60,12 +66,12 @@ def test_load_yaml_non_existing_file(): def test_load_valid_yaml_file(prepare_example_yaml_file): yaml_data = load_yaml_file(file_path=prepare_example_yaml_file) assert len(yaml_data) > 0 - assert yaml_data['age'] == 30 - assert yaml_data['gender'] == 'male' - assert yaml_data['address']['city'] == 'Example City' - assert set(yaml_data['languages']) == {'Python', 'Java', 'C++'} - assert yaml_data.get('empty_key') is None - assert yaml_data.get('non_existed_key') is None + assert yaml_data["age"] == 30 + assert yaml_data["gender"] == "male" + assert yaml_data["address"]["city"] == "Example City" + assert set(yaml_data["languages"]) == {"Python", "Java", "C++"} + assert yaml_data.get("empty_key") is None + assert yaml_data.get("non_existed_key") is None def test_load_invalid_yaml_file(prepare_invalid_yaml_file):