From 3c91f9b5abc51d908c36ff033829b2e5a54fa8e8 Mon Sep 17 00:00:00 2001
From: crazywoola <100913391+crazywoola@users.noreply.github.com>
Date: Mon, 11 Mar 2024 09:26:15 +0800
Subject: [PATCH 001/200] fix: dataset segements api (#2766)

---
 api/controllers/service_api/dataset/segment.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py
index d4a6b6aa4f..3f7cfcaea8 100644
--- a/api/controllers/service_api/dataset/segment.py
+++ b/api/controllers/service_api/dataset/segment.py
@@ -200,8 +200,8 @@ class DatasetSegmentApi(DatasetApiResource):
         parser.add_argument('segments', type=dict, required=False, nullable=True, location='json')
         args = parser.parse_args()
 
-        SegmentService.segment_create_args_validate(args['segments'], document)
-        segment = SegmentService.update_segment(args['segments'], segment, document, dataset)
+        SegmentService.segment_create_args_validate(args, document)
+        segment = SegmentService.update_segment(args, segment, document, dataset)
         return {
             'data': marshal(segment, segment_fields),
             'doc_form': document.doc_form

From 53cd125780cfe93db45dc0419cd5cc919f7eb8d5 Mon Sep 17 00:00:00 2001
From: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
Date: Mon, 11 Mar 2024 10:27:00 +0800
Subject: [PATCH 002/200] fix: deep copy of model-tool label (#2775)

---
 .../tools/provider/model_tool_provider.py | 23 ++++++++++++-------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/api/core/tools/provider/model_tool_provider.py b/api/core/tools/provider/model_tool_provider.py
index 733b2fd02a..ef47e9aae9 100644
--- a/api/core/tools/provider/model_tool_provider.py
+++ b/api/core/tools/provider/model_tool_provider.py
@@ -1,3 +1,4 @@
+from copy import deepcopy
 from typing import Any
 
 from core.entities.model_entities import ModelStatus
@@ -59,10 +60,16 @@ class ModelToolProviderController(ToolProviderController):
 
         # override the configuration
         if model_tool_configuration.label:
-            if model_tool_configuration.label.en_US:
-                configuration.provider.label.en_US = model_tool_configuration.label.en_US
-            if model_tool_configuration.label.zh_Hans:
-                configuration.provider.label.zh_Hans = model_tool_configuration.label.zh_Hans
+            label = deepcopy(model_tool_configuration.label)
+            if label.en_US:
+                label.en_US = model_tool_configuration.label.en_US
+            if label.zh_Hans:
+                label.zh_Hans = model_tool_configuration.label.zh_Hans
+        else:
+            label = I18nObject(
+                en_US=configuration.provider.label.en_US,
+                zh_Hans=configuration.provider.label.zh_Hans
+            )
 
         return ModelToolProviderController(
             is_active=is_active,
             author='Dify',
             name=configuration.provider.provider,
             description=I18nObject(
-                zh_Hans=f'{configuration.provider.label.zh_Hans} 模型能力提供商',
-                en_US=f'{configuration.provider.label.en_US} model capability provider'
+                zh_Hans=f'{label.zh_Hans} 模型能力提供商',
+                en_US=f'{label.en_US} model capability provider'
             ),
             label=I18nObject(
-                zh_Hans=configuration.provider.label.zh_Hans,
-                en_US=configuration.provider.label.en_US
+                zh_Hans=label.zh_Hans,
+                en_US=label.en_US
             ),
             icon=configuration.provider.icon_small.en_US,
         ),
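Patch 002 above exists because assigning a nested configuration object only copies a reference, so a later "override" silently mutates the shared provider label. A minimal, self-contained sketch of that aliasing pitfall, using an illustrative `Label` class rather than Dify's real entities:

```python
# Illustrative sketch (not part of the commit): why the fix uses deepcopy.
# Assigning a nested object copies a reference, so mutating the "copy"
# also changes the shared original; deepcopy yields an independent object.
from copy import deepcopy
from dataclasses import dataclass

@dataclass
class Label:
    en_US: str
    zh_Hans: str

shared = Label(en_US="Tool", zh_Hans="工具")

alias = shared                  # same object: edits leak back to `shared`
alias.en_US = "Renamed"
assert shared.en_US == "Renamed"

independent = deepcopy(shared)  # separate object: edits stay local
independent.en_US = "Local name"
assert shared.en_US == "Renamed"
```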
From 82d0732c12fe3d547d00b51cd82930f74e066b6e Mon Sep 17 00:00:00 2001
From: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
Date: Mon, 11 Mar 2024 14:04:09 +0800
Subject: [PATCH 003/200] fix: aippt default styles (#2779)

---
 api/core/tools/provider/builtin/aippt/tools/aippt.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api/core/tools/provider/builtin/aippt/tools/aippt.py b/api/core/tools/provider/builtin/aippt/tools/aippt.py
index 615033f5d9..81465848a2 100644
--- a/api/core/tools/provider/builtin/aippt/tools/aippt.py
+++ b/api/core/tools/provider/builtin/aippt/tools/aippt.py
@@ -455,7 +455,7 @@ class AIPPTGenerateTool(BuiltinTool):
         :return: Tuple[list[dict[id, color]], list[dict[id, style]]
         """
         if not self.runtime.credentials.get('aippt_access_key') or not self.runtime.credentials.get('aippt_secret_key'):
-            return [], []
+            raise Exception('Please provide aippt credentials')
 
         return self._get_styles(credentials=self.runtime.credentials, user_id=user_id)

From 7001e21e7d5bbbf30a3abc34f92200763148ff0c Mon Sep 17 00:00:00 2001
From: Eric Wang
Date: Mon, 11 Mar 2024 14:11:51 +0800
Subject: [PATCH 004/200] overview: fix filter today calc start & end (#2777)

---
 .../(appDetailLayout)/[appId]/overview/chartView.tsx | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx
index cd55faf8e1..93141dd86e 100644
--- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx
+++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx
@@ -29,7 +29,17 @@ export default function ChartView({ appId }: IChartViewProps) {
   const [period, setPeriod] = useState({ name: t('appLog.filter.period.last7days'), query: { start: today.subtract(7, 'day').format(queryDateFormat), end: today.format(queryDateFormat) } })
 
   const onSelect = (item: Item) => {
-    setPeriod({ name: item.name, query: item.value === 'all' ? undefined : { start: today.subtract(item.value as number, 'day').format(queryDateFormat), end: today.format(queryDateFormat) } })
+    if (item.value === 'all') {
+      setPeriod({ name: item.name, query: undefined })
+    }
+    else if (item.value === 0) {
+      const startOfToday = today.startOf('day').format(queryDateFormat)
+      const endOfToday = today.endOf('day').format(queryDateFormat)
+      setPeriod({ name: item.name, query: { start: startOfToday, end: endOfToday } })
+    }
+    else {
+      setPeriod({ name: item.name, query: { start: startOfToday, end: endOfToday } })
+    }
   }
 
   if (!response)
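The bug patch 004 fixes is that a "today" filter built as `today.subtract(0, 'day')..today` collapses to a zero-length range ending at the current moment. The frontend fix uses dayjs in TypeScript; the same boundary logic, sketched in Python for illustration only (this code is not in the repository):

```python
# Illustrative analogue of the dayjs fix: a "today" filter must span the
# whole calendar day [00:00:00, 23:59:59.999999], not the range [now, now].
from datetime import datetime, time

now = datetime.now()
start_of_today = datetime.combine(now.date(), time.min)  # 00:00:00
end_of_today = datetime.combine(now.date(), time.max)    # 23:59:59.999999

query = {"start": start_of_today.isoformat(), "end": end_of_today.isoformat()}
print(query)
```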
From fdd211e3994bbfbb71b2d78534ebf6499ffb6f0e Mon Sep 17 00:00:00 2001
From: Eric Wang
Date: Mon, 11 Mar 2024 14:16:31 +0800
Subject: [PATCH 005/200] debug/chat: increase notify error duration to 3000
 (#2778)

---
 api/README.md                                        | 6 +++---
 web/app/components/app/configuration/debug/index.tsx | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/api/README.md b/api/README.md
index 1d3559c694..5c3a667801 100644
--- a/api/README.md
+++ b/api/README.md
@@ -5,7 +5,7 @@
 1. Start the docker-compose stack
 
    The backend require some middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.
-   
+
    ```bash
    cd ../docker
    docker-compose -f docker-compose.middleware.yaml -p dify up -d
    ```
 
 3. Generate a `SECRET_KEY` in the `.env` file.
 
    ```bash
-   openssl rand -base64 42
+   sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
    ```
 3.5 If you use annaconda, create a new environment and activate it
    ```bash
@@ -46,7 +46,7 @@
    ```
    pip install -r requirements.txt --upgrade --force-reinstall
    ```
-   
+
 6. Start backend:
    ```bash
    flask run --host 0.0.0.0 --port=5001 --debug
    ```
diff --git a/web/app/components/app/configuration/debug/index.tsx b/web/app/components/app/configuration/debug/index.tsx
index 6cf42d07b6..a7fd2d5ef7 100644
--- a/web/app/components/app/configuration/debug/index.tsx
+++ b/web/app/components/app/configuration/debug/index.tsx
@@ -130,7 +130,7 @@ const Debug: FC = ({
   const { notify } = useContext(ToastContext)
   const logError = useCallback((message: string) => {
-    notify({ type: 'error', message })
+    notify({ type: 'error', message, duration: 3000 })
   }, [notify])
 
   const [completionFiles, setCompletionFiles] = useState([])
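The api/README.md change in patch 005 replaces a command that merely prints a random key with a `sed` one-liner that writes it into the `SECRET_KEY=` line of `.env` in place. A rough Python equivalent of what that sed command does, assuming a `.env` file that already contains a `SECRET_KEY=` line; the helper name is hypothetical and not part of the repository:

```python
# Rough, hypothetical Python equivalent of the README's sed one-liner:
# generate a base64 key (like `openssl rand -base64 42`) and rewrite the
# whole SECRET_KEY=... line in .env instead of printing the key to stdout.
import base64
import secrets
from pathlib import Path

def write_secret_key(env_path: str = ".env") -> None:
    key = base64.b64encode(secrets.token_bytes(42)).decode()
    path = Path(env_path)
    lines = path.read_text().splitlines()
    lines = [f"SECRET_KEY={key}" if line.startswith("SECRET_KEY=") else line
             for line in lines]
    path.write_text("\n".join(lines) + "\n")
```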
From e184c8cb42af7f80f9def4107c82e7878d4e0277 Mon Sep 17 00:00:00 2001
From: crazywoola <100913391+crazywoola@users.noreply.github.com>
Date: Mon, 11 Mar 2024 14:53:40 +0800
Subject: [PATCH 006/200] Update README.md (#2780)

---
 README.md | 11 -----------
 1 file changed, 11 deletions(-)

diff --git a/README.md b/README.md
index 14111ee060..7e2740b10e 100644
--- a/README.md
+++ b/README.md
@@ -21,17 +21,6 @@
     Docker Pulls
 
-
-  Dify.AI Upcoming Meetup Event [👉 Click to Join the Event Here 👈]
-
-  • US EST: 09:00 (9:00 AM)
-  • CET: 15:00 (3:00 PM)
-  • CST: 22:00 (10:00 PM)
-
-
-
-
-
 
   Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs
 

From 6454e1d644a45f01f9c8f5db16187bea81859575 Mon Sep 17 00:00:00 2001
From: Jyong <76649700+JohnJyong@users.noreply.github.com>
Date: Mon, 11 Mar 2024 15:36:56 +0800
Subject: [PATCH 007/200] chunk-overlap None check (#2781)

Co-authored-by: jyong

---
 api/core/indexing_runner.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index 0cd9f9f646..dd46aa27dc 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -416,9 +416,14 @@ class IndexingRunner:
             if separator:
                 separator = separator.replace('\\n', '\n')
 
+            if 'chunk_overlap' in segmentation and segmentation['chunk_overlap']:
+                chunk_overlap = segmentation['chunk_overlap']
+            else:
+                chunk_overlap = 0
+
             character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
                 chunk_size=segmentation["max_tokens"],
-                chunk_overlap=segmentation.get('chunk_overlap', 0),
+                chunk_overlap=chunk_overlap,
                 fixed_separator=separator,
                 separators=["\n\n", "。", ".", " ", ""],
                 embedding_model_instance=embedding_model_instance
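Patch 007 looks redundant at first glance, since `dict.get('chunk_overlap', 0)` already supplies a default. The catch is that `.get` only applies the default when the key is absent; a key that is present with a `None` value passes `None` straight through to the splitter. A tiny demonstration of the difference:

```python
# Why `.get('chunk_overlap', 0)` was not enough: the key can exist with a
# None value, and the .get default only applies when the key is missing.
segmentation = {"max_tokens": 500, "chunk_overlap": None}

broken = segmentation.get("chunk_overlap", 0)  # None, not 0
fixed = segmentation["chunk_overlap"] if segmentation.get("chunk_overlap") else 0

assert broken is None
assert fixed == 0
```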
From f8951d7f57a02bd3d1fcc9e8e4958804ba2055fb Mon Sep 17 00:00:00 2001
From: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
Date: Mon, 11 Mar 2024 18:21:41 +0800
Subject: [PATCH 008/200] fix: api tool provider not found (#2782)

---
 api/controllers/console/app/app.py          | 41 +++++++++++----------
 api/controllers/console/app/model_config.py | 38 +++++++++++--------
 2 files changed, 44 insertions(+), 35 deletions(-)

diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index 4b648a4e28..fb697b6418 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -247,27 +247,30 @@ class AppApi(Resource):
             for tool in agent_mode.get('tools') or []:
                 agent_tool_entity = AgentToolEntity(**tool)
                 # get tool
-                tool_runtime = ToolManager.get_agent_tool_runtime(
-                    tenant_id=current_user.current_tenant_id,
-                    agent_tool=agent_tool_entity,
-                    agent_callback=None
-                )
-                manager = ToolParameterConfigurationManager(
-                    tenant_id=current_user.current_tenant_id,
-                    tool_runtime=tool_runtime,
-                    provider_name=agent_tool_entity.provider_id,
-                    provider_type=agent_tool_entity.provider_type,
-                )
+                try:
+                    tool_runtime = ToolManager.get_agent_tool_runtime(
+                        tenant_id=current_user.current_tenant_id,
+                        agent_tool=agent_tool_entity,
+                        agent_callback=None
+                    )
+                    manager = ToolParameterConfigurationManager(
+                        tenant_id=current_user.current_tenant_id,
+                        tool_runtime=tool_runtime,
+                        provider_name=agent_tool_entity.provider_id,
+                        provider_type=agent_tool_entity.provider_type,
+                    )
 
-                # get decrypted parameters
-                if agent_tool_entity.tool_parameters:
-                    parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
-                    masked_parameter = manager.mask_tool_parameters(parameters or {})
-                else:
-                    masked_parameter = {}
+                    # get decrypted parameters
+                    if agent_tool_entity.tool_parameters:
+                        parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
+                        masked_parameter = manager.mask_tool_parameters(parameters or {})
+                    else:
+                        masked_parameter = {}
 
-                # override tool parameters
-                tool['tool_parameters'] = masked_parameter
+                    # override tool parameters
+                    tool['tool_parameters'] = masked_parameter
+                except Exception as e:
+                    pass
 
             # override agent mode
             model_config.agent_mode = json.dumps(agent_mode)
diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py
index 117007d055..436f8c1447 100644
--- a/api/controllers/console/app/model_config.py
+++ b/api/controllers/console/app/model_config.py
@@ -54,17 +54,20 @@ class ModelConfigResource(Resource):
             for tool in agent_mode.get('tools') or []:
                 agent_tool_entity = AgentToolEntity(**tool)
                 # get tool
-                tool_runtime = ToolManager.get_agent_tool_runtime(
-                    tenant_id=current_user.current_tenant_id,
-                    agent_tool=agent_tool_entity,
-                    agent_callback=None
-                )
-                manager = ToolParameterConfigurationManager(
-                    tenant_id=current_user.current_tenant_id,
-                    tool_runtime=tool_runtime,
-                    provider_name=agent_tool_entity.provider_id,
-                    provider_type=agent_tool_entity.provider_type,
-                )
+                try:
+                    tool_runtime = ToolManager.get_agent_tool_runtime(
+                        tenant_id=current_user.current_tenant_id,
+                        agent_tool=agent_tool_entity,
+                        agent_callback=None
+                    )
+                    manager = ToolParameterConfigurationManager(
+                        tenant_id=current_user.current_tenant_id,
+                        tool_runtime=tool_runtime,
+                        provider_name=agent_tool_entity.provider_id,
+                        provider_type=agent_tool_entity.provider_type,
+                    )
+                except Exception as e:
+                    continue
 
                 # get decrypted parameters
                 if agent_tool_entity.tool_parameters:
@@ -89,11 +92,14 @@
                 if key in tool_map:
                     tool_runtime = tool_map[key]
                 else:
-                    tool_runtime = ToolManager.get_agent_tool_runtime(
-                        tenant_id=current_user.current_tenant_id,
-                        agent_tool=agent_tool_entity,
-                        agent_callback=None
-                    )
+                    try:
+                        tool_runtime = ToolManager.get_agent_tool_runtime(
+                            tenant_id=current_user.current_tenant_id,
+                            agent_tool=agent_tool_entity,
+                            agent_callback=None
+                        )
+                    except Exception as e:
+                        continue
 
                 manager = ToolParameterConfigurationManager(
                     tenant_id=current_user.current_tenant_id,
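The shape of patch 008 is a common degrade-gracefully pattern: when one configured tool's provider has been deleted or misconfigured, skip that tool instead of failing the whole app-config request. A generic sketch of the pattern, with hypothetical `load_tools`/`get_runtime` names that are not Dify's actual API:

```python
# Generic sketch of the pattern applied in patch 008: one unresolvable tool
# provider should not take down the entire request.
def load_tools(tool_configs, get_runtime):
    tools = []
    for config in tool_configs:
        try:
            tools.append(get_runtime(config))
        except Exception:
            # provider was removed or is misconfigured; keep the rest usable
            continue
    return tools
```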
From e54c9cd40160e565709fe5ad31cd069aaa97875b Mon Sep 17 00:00:00 2001
From: Jyong <76649700+JohnJyong@users.noreply.github.com>
Date: Mon, 11 Mar 2024 19:48:21 +0800
Subject: [PATCH 009/200] Feat/open ai compatible functioncall (#2783)

Co-authored-by: jyong

---
 .../model_providers/cohere/llm/llm.py         |   2 +-
 .../mistralai/llm/mistral-large-latest.yaml   |   2 +-
 .../mistralai/llm/mistral-medium-latest.yaml  |   2 +-
 .../mistralai/llm/mistral-small-latest.yaml   |   2 +-
 .../mistralai/llm/open-mistral-7b.yaml        |   2 +-
 .../mistralai/llm/open-mixtral-8x7b.yaml      |   2 +-
 .../openai_api_compatible/llm/llm.py          | 111 ++++++++++++------
 .../openai_api_compatible.yaml                |  22 ++++
 8 files changed, 105 insertions(+), 40 deletions(-)

diff --git a/api/core/model_runtime/model_providers/cohere/llm/llm.py b/api/core/model_runtime/model_providers/cohere/llm/llm.py
index 667ba4c78c..50805bce85 100644
--- a/api/core/model_runtime/model_providers/cohere/llm/llm.py
+++ b/api/core/model_runtime/model_providers/cohere/llm/llm.py
@@ -472,7 +472,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         else:
             raise ValueError(f"Got unknown type {message}")
 
-        if message.name is not None:
+        if message.name:
             message_dict["user_name"] = message.name
 
         return message_dict
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/mistral-large-latest.yaml b/api/core/model_runtime/model_providers/mistralai/llm/mistral-large-latest.yaml
index b729012c40..b8ed8ba934 100644
--- a/api/core/model_runtime/model_providers/mistralai/llm/mistral-large-latest.yaml
+++ b/api/core/model_runtime/model_providers/mistralai/llm/mistral-large-latest.yaml
@@ -24,7 +24,7 @@ parameter_rules:
     min: 1
     max: 8000
   - name: safe_prompt
-    defulat: false
+    default: false
     type: boolean
     help:
       en_US: Whether to inject a safety prompt before all conversations.
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/mistral-medium-latest.yaml b/api/core/model_runtime/model_providers/mistralai/llm/mistral-medium-latest.yaml
index 6e586b4843..bf6f1b2d1d 100644
--- a/api/core/model_runtime/model_providers/mistralai/llm/mistral-medium-latest.yaml
+++ b/api/core/model_runtime/model_providers/mistralai/llm/mistral-medium-latest.yaml
@@ -24,7 +24,7 @@ parameter_rules:
     min: 1
     max: 8000
   - name: safe_prompt
-    defulat: false
+    default: false
     type: boolean
     help:
       en_US: Whether to inject a safety prompt before all conversations.
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/mistral-small-latest.yaml b/api/core/model_runtime/model_providers/mistralai/llm/mistral-small-latest.yaml
index 4e7e6147f5..111cd05457 100644
--- a/api/core/model_runtime/model_providers/mistralai/llm/mistral-small-latest.yaml
+++ b/api/core/model_runtime/model_providers/mistralai/llm/mistral-small-latest.yaml
@@ -24,7 +24,7 @@ parameter_rules:
     min: 1
     max: 8000
   - name: safe_prompt
-    defulat: false
+    default: false
    type: boolean
     help:
       en_US: Whether to inject a safety prompt before all conversations.
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/open-mistral-7b.yaml b/api/core/model_runtime/model_providers/mistralai/llm/open-mistral-7b.yaml
index 30454f7df2..4f72648662 100644
--- a/api/core/model_runtime/model_providers/mistralai/llm/open-mistral-7b.yaml
+++ b/api/core/model_runtime/model_providers/mistralai/llm/open-mistral-7b.yaml
@@ -24,7 +24,7 @@ parameter_rules:
     min: 1
     max: 2048
   - name: safe_prompt
-    defulat: false
+    default: false
     type: boolean
     help:
       en_US: Whether to inject a safety prompt before all conversations.
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x7b.yaml b/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x7b.yaml
index a35cf0a9ae..719de29c3a 100644
--- a/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x7b.yaml
+++ b/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x7b.yaml
@@ -24,7 +24,7 @@ parameter_rules:
     min: 1
     max: 8000
   - name: safe_prompt
-    defulat: false
+    default: false
     type: boolean
     help:
       en_US: Whether to inject a safety prompt before all conversations.
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
index d294fcaa9c..8cfec0e34b 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
@@ -25,6 +25,7 @@ from core.model_runtime.entities.model_entities import (
     AIModelEntity,
     DefaultParameterName,
     FetchFrom,
+    ModelFeature,
     ModelPropertyKey,
     ModelType,
     ParameterRule,
@@ -166,11 +167,23 @@
         """
         generate custom model entities from credentials
         """
+        support_function_call = False
+        features = []
+        function_calling_type = credentials.get('function_calling_type', 'no_call')
+        if function_calling_type == 'function_call':
+            features = [ModelFeature.TOOL_CALL]
+            support_function_call = True
+        endpoint_url = credentials["endpoint_url"]
+        # if not endpoint_url.endswith('/'):
+        #     endpoint_url += '/'
+        # if 'https://api.openai.com/v1/' == endpoint_url:
+        #     features = [ModelFeature.STREAM_TOOL_CALL]
         entity = AIModelEntity(
             model=model,
             label=I18nObject(en_US=model),
             model_type=ModelType.LLM,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            features=features if support_function_call else [],
             model_properties={
                 ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")),
                 ModelPropertyKey.MODE: credentials.get('mode'),
@@ -194,14 +207,6 @@
                     max=1,
                     precision=2
                 ),
-                ParameterRule(
-                    name="top_k",
-                    label=I18nObject(en_US="Top K"),
-                    type=ParameterType.INT,
-                    default=int(credentials.get('top_k', 1)),
-                    min=1,
-                    max=100
-                ),
                 ParameterRule(
                     name=DefaultParameterName.FREQUENCY_PENALTY.value,
                     label=I18nObject(en_US="Frequency Penalty"),
@@ -232,7 +237,7 @@
                 output=Decimal(credentials.get('output_price', 0)),
                 unit=Decimal(credentials.get('unit', 0)),
                 currency=credentials.get('currency', "USD")
-            )
+            ),
         )
 
         if credentials['mode'] == 'chat':
@@ -292,14 +297,22 @@
             raise ValueError("Unsupported completion type for model configuration.")
 
         # annotate tools with names, descriptions, etc.
+        function_calling_type = credentials.get('function_calling_type', 'no_call')
        formatted_tools = []
         if tools:
-            data["tool_choice"] = "auto"
+            if function_calling_type == 'function_call':
+                data['functions'] = [{
+                    "name": tool.name,
+                    "description": tool.description,
+                    "parameters": tool.parameters
+                } for tool in tools]
+            elif function_calling_type == 'tool_call':
+                data["tool_choice"] = "auto"
 
-            for tool in tools:
-                formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
+                for tool in tools:
+                    formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
 
-            data["tools"] = formatted_tools
+                data["tools"] = formatted_tools
 
         if stop:
             data["stop"] = stop
@@ -367,9 +380,9 @@
 
         for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             if chunk:
-                #ignore sse comments
+                # ignore sse comments
                 if chunk.startswith(':'):
-                    continue
+                    continue
                 decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
                 chunk_json = None
                 try:
@@ -452,10 +465,13 @@
 
         response_content = ''
         tool_calls = None
-
+        function_calling_type = credentials.get('function_calling_type', 'no_call')
         if completion_type is LLMMode.CHAT:
             response_content = output.get('message', {})['content']
-            tool_calls = output.get('message', {}).get('tool_calls')
+            if function_calling_type == 'tool_call':
+                tool_calls = output.get('message', {}).get('tool_calls')
+            elif function_calling_type == 'function_call':
+                tool_calls = output.get('message', {}).get('function_call')
 
         elif completion_type is LLMMode.COMPLETION:
             response_content = output['text']
@@ -463,7 +479,10 @@
         assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[])
 
         if tool_calls:
-            assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)
+            if function_calling_type == 'tool_call':
+                assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)
+            elif function_calling_type == 'function_call':
+                assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)]
 
         usage = response_json.get("usage")
         if usage:
@@ -522,33 +541,34 @@
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
             if message.tool_calls:
-                message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
-                                              in
-                                              message.tool_calls]
-                # function_call = message.tool_calls[0]
-                # message_dict["function_call"] = {
-                #     "name": function_call.function.name,
-                #     "arguments": function_call.function.arguments,
-                # }
+                # message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
+                #                               in
+                #                               message.tool_calls]
+
+                function_call = message.tool_calls[0]
+                message_dict["function_call"] = {
+                    "name": function_call.function.name,
+                    "arguments": function_call.function.arguments,
+                }
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
         elif isinstance(message, ToolPromptMessage):
             message = cast(ToolPromptMessage, message)
-            message_dict = {
-                "role": "tool",
-                "content": message.content,
-                "tool_call_id": message.tool_call_id
-            }
             # message_dict = {
-            #     "role": 
"function", + # "role": "tool", # "content": message.content, - # "name": message.tool_call_id + # "tool_call_id": message.tool_call_id # } + message_dict = { + "role": "function", + "content": message.content, + "name": message.tool_call_id + } else: raise ValueError(f"Got unknown type {message}") - if message.name is not None: + if message.name: message_dict["name"] = message.name return message_dict @@ -693,3 +713,26 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): tool_calls.append(tool_call) return tool_calls + + def _extract_response_function_call(self, response_function_call) \ + -> AssistantPromptMessage.ToolCall: + """ + Extract function call from response + + :param response_function_call: response function call + :return: tool call + """ + tool_call = None + if response_function_call: + function = AssistantPromptMessage.ToolCall.ToolCallFunction( + name=response_function_call['name'], + arguments=response_function_call['arguments'] + ) + + tool_call = AssistantPromptMessage.ToolCall( + id=response_function_call['name'], + type="function", + function=function + ) + + return tool_call diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml index 213d334fe8..be99f7684c 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml +++ b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml @@ -75,6 +75,28 @@ model_credential_schema: value: llm default: '4096' type: text-input + - variable: function_calling_type + show_on: + - variable: __model_type + value: llm + label: + en_US: Function calling + type: select + required: false + default: no_call + options: + - value: function_call + label: + en_US: Support + zh_Hans: 支持 +# - value: tool_call +# label: +# en_US: Tool Call +# zh_Hans: Tool Call + - value: no_call + label: + en_US: Not Support + zh_Hans: 不支持 - variable: stream_mode_delimiter label: zh_Hans: 流模式返回结果的分隔符 From 796c5626a7437ad8d77b1654dcaf401dd3105855 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Mon, 11 Mar 2024 23:57:38 +0800 Subject: [PATCH 010/200] fix delete dataset when dataset has no document (#2789) Co-authored-by: jyong --- api/tasks/clean_dataset_task.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 37e109c847..b9737d7ddd 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -44,6 +44,10 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all() segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() + if documents is None or len(documents) == 0: + logging.info(click.style('No documents found for dataset: {}'.format(dataset_id), fg='green')) + return + index_processor = IndexProcessorFactory(doc_form).init_index_processor() index_processor.clean(dataset, None) From f49b1afd6c5969fe660826d98de92248421dfa2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=86=E8=90=8C=E9=97=B7=E6=B2=B9=E7=93=B6?= <253605712@qq.com> Date: Tue, 12 Mar 2024 12:06:35 +0800 Subject: [PATCH 011/200] feat:support azure tts (#2751) --- .../model_providers/azure_openai/_constant.py | 110 +++++++++++ .../azure_openai/azure_openai.yaml | 13 ++ 
.../azure_openai/tts/__init__.py | 0 .../model_providers/azure_openai/tts/tts.py | 174 ++++++++++++++++++ api/requirements.txt | 2 +- 5 files changed, 298 insertions(+), 1 deletion(-) create mode 100644 api/core/model_runtime/model_providers/azure_openai/tts/__init__.py create mode 100644 api/core/model_runtime/model_providers/azure_openai/tts/tts.py diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py index 7fc0da73fb..9779fa71a0 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py @@ -583,3 +583,113 @@ SPEECH2TEXT_BASE_MODELS = [ ) ) ] +TTS_BASE_MODELS = [ + AzureBaseModel( + base_model_name='tts-1', + entity=AIModelEntity( + model='fake-deployment-name', + label=I18nObject( + en_US='fake-deployment-name-label' + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.TTS, + model_properties={ + ModelPropertyKey.DEFAULT_VOICE: 'alloy', + ModelPropertyKey.VOICES: [ + { + 'mode': 'alloy', + 'name': 'Alloy', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + { + 'mode': 'echo', + 'name': 'Echo', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + { + 'mode': 'fable', + 'name': 'Fable', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + { + 'mode': 'onyx', + 'name': 'Onyx', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + { + 'mode': 'nova', + 'name': 'Nova', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + { + 'mode': 'shimmer', + 'name': 'Shimmer', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + ], + ModelPropertyKey.WORD_LIMIT: 120, + ModelPropertyKey.AUDOI_TYPE: 'mp3', + ModelPropertyKey.MAX_WORKERS: 5 + }, + pricing=PriceConfig( + input=0.015, + unit=0.001, + currency='USD', + ) + ) + ), + AzureBaseModel( + base_model_name='tts-1-hd', + entity=AIModelEntity( + model='fake-deployment-name', + label=I18nObject( + en_US='fake-deployment-name-label' + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.TTS, + model_properties={ + ModelPropertyKey.DEFAULT_VOICE: 'alloy', + ModelPropertyKey.VOICES: [ + { + 'mode': 'alloy', + 'name': 'Alloy', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + { + 'mode': 'echo', + 'name': 'Echo', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + { + 'mode': 'fable', + 'name': 'Fable', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + { + 'mode': 'onyx', + 'name': 'Onyx', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + { + 'mode': 'nova', + 'name': 'Nova', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + { + 'mode': 'shimmer', + 'name': 'Shimmer', + 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + }, + ], + ModelPropertyKey.WORD_LIMIT: 120, + ModelPropertyKey.AUDOI_TYPE: 'mp3', + ModelPropertyKey.MAX_WORKERS: 5 + }, + pricing=PriceConfig( + input=0.03, + unit=0.001, + currency='USD', + ) + ) + ) +] diff --git a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml 
b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml index 6c56ccc920..58800ddee2 100644 --- a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml +++ b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml @@ -16,6 +16,7 @@ supported_model_types: - llm - text-embedding - speech2text + - tts configurate_methods: - customizable-model model_credential_schema: @@ -118,6 +119,18 @@ model_credential_schema: show_on: - variable: __model_type value: speech2text + - label: + en_US: tts-1 + value: tts-1 + show_on: + - variable: __model_type + value: tts + - label: + en_US: tts-1-hd + value: tts-1-hd + show_on: + - variable: __model_type + value: tts placeholder: zh_Hans: 在此输入您的模型版本 en_US: Enter your model version diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/__init__.py b/api/core/model_runtime/model_providers/azure_openai/tts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py new file mode 100644 index 0000000000..585b061afe --- /dev/null +++ b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py @@ -0,0 +1,174 @@ +import concurrent.futures +import copy +from functools import reduce +from io import BytesIO +from typing import Optional + +from flask import Response, stream_with_context +from openai import AzureOpenAI +from pydub import AudioSegment + +from core.model_runtime.entities.model_entities import AIModelEntity +from core.model_runtime.errors.invoke import InvokeBadRequestError +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.tts_model import TTSModel +from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI +from core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_MODELS, AzureBaseModel +from extensions.ext_storage import storage + + +class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): + """ + Model class for OpenAI Speech to text model. 
+ """ + + def _invoke(self, model: str, tenant_id: str, credentials: dict, + content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any: + """ + _invoke text2speech model + + :param model: model name + :param tenant_id: user tenant id + :param credentials: model credentials + :param content_text: text content to be translated + :param voice: model timbre + :param streaming: output is streaming + :param user: unique user id + :return: text translated to audio file + """ + audio_type = self._get_model_audio_type(model, credentials) + if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]: + voice = self._get_model_default_voice(model, credentials) + if streaming: + return Response(stream_with_context(self._tts_invoke_streaming(model=model, + credentials=credentials, + content_text=content_text, + tenant_id=tenant_id, + voice=voice)), + status=200, mimetype=f'audio/{audio_type}') + else: + return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice) + + def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: + """ + validate credentials text2speech model + + :param model: model name + :param credentials: model credentials + :param user: unique user id + :return: text translated to audio file + """ + try: + self._tts_invoke( + model=model, + credentials=credentials, + content_text='Hello Dify!', + voice=self._get_model_default_voice(model, credentials), + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> Response: + """ + _tts_invoke text2speech model + + :param model: model name + :param credentials: model credentials + :param content_text: text content to be translated + :param voice: model timbre + :return: text translated to audio file + """ + audio_type = self._get_model_audio_type(model, credentials) + word_limit = self._get_model_word_limit(model, credentials) + max_workers = self._get_model_workers_limit(model, credentials) + try: + sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) + audio_bytes_list = list() + + # Create a thread pool and map the function to the list of sentences + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [executor.submit(self._process_sentence, sentence=sentence, model=model, voice=voice, + credentials=credentials) for sentence in sentences] + for future in futures: + try: + if future.result(): + audio_bytes_list.append(future.result()) + except Exception as ex: + raise InvokeBadRequestError(str(ex)) + + if len(audio_bytes_list) > 0: + audio_segments = [AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type) for audio_bytes in + audio_bytes_list if audio_bytes] + combined_segment = reduce(lambda x, y: x + y, audio_segments) + buffer: BytesIO = BytesIO() + combined_segment.export(buffer, format=audio_type) + buffer.seek(0) + return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}") + except Exception as ex: + raise InvokeBadRequestError(str(ex)) + + # Todo: To improve the streaming function + def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str, + voice: str) -> any: + """ + _tts_invoke_streaming text2speech model + + :param model: model name + :param tenant_id: user tenant id + :param credentials: model credentials + :param 
content_text: text content to be translated + :param voice: model timbre + :return: text translated to audio file + """ + # transform credentials to kwargs for model instance + credentials_kwargs = self._to_credential_kwargs(credentials) + if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials): + voice = self._get_model_default_voice(model, credentials) + word_limit = self._get_model_word_limit(model, credentials) + audio_type = self._get_model_audio_type(model, credentials) + tts_file_id = self._get_file_name(content_text) + file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}' + try: + client = AzureOpenAI(**credentials_kwargs) + sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) + for sentence in sentences: + response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip()) + # response.stream_to_file(file_path) + storage.save(file_path, response.read()) + except Exception as ex: + raise InvokeBadRequestError(str(ex)) + + def _process_sentence(self, sentence: str, model: str, + voice, credentials: dict): + """ + _tts_invoke openai text2speech model api + + :param model: model name + :param credentials: model credentials + :param voice: model timbre + :param sentence: text content to be translated + :return: text translated to audio file + """ + # transform credentials to kwargs for model instance + credentials_kwargs = self._to_credential_kwargs(credentials) + client = AzureOpenAI(**credentials_kwargs) + response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip()) + if isinstance(response.read(), bytes): + return response.read() + + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) + return ai_model_entity.entity + + + @staticmethod + def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel: + for ai_model_entity in TTS_BASE_MODELS: + if ai_model_entity.base_model_name == base_model_name: + ai_model_entity_copy = copy.deepcopy(ai_model_entity) + ai_model_entity_copy.entity.model = model + ai_model_entity_copy.entity.label.en_US = model + ai_model_entity_copy.entity.label.zh_Hans = model + return ai_model_entity_copy + + return None diff --git a/api/requirements.txt b/api/requirements.txt index 847903c4f4..7edd95a893 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -11,7 +11,7 @@ flask-cors~=4.0.0 gunicorn~=21.2.0 gevent~=23.9.1 langchain==0.0.250 -openai~=1.3.6 +openai~=1.13.3 tiktoken~=0.5.2 psycopg2-binary~=2.9.6 pycryptodome==3.19.1 From f82a64d149b001ac2b875bf5142acce0da7f5265 Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Tue, 12 Mar 2024 13:45:59 +0800 Subject: [PATCH 012/200] =?UTF-8?q?feat:=20add=20DingTalk(=E9=92=89?= =?UTF-8?q?=E9=92=89)=20tool=20for=20sending=20message=20to=20chat=20group?= =?UTF-8?q?=20bot=20via=20webhook=20(#2693)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/core/tools/provider/_position.yaml | 1 + .../builtin/dingtalk/_assets/icon.svg | 7 ++ .../provider/builtin/dingtalk/dingtalk.py | 8 ++ .../provider/builtin/dingtalk/dingtalk.yaml | 13 +++ .../dingtalk/tools/dingtalk_group_bot.py | 83 +++++++++++++++++++ .../dingtalk/tools/dingtalk_group_bot.yaml | 52 ++++++++++++ 6 files changed, 164 insertions(+) create mode 100644 api/core/tools/provider/builtin/dingtalk/_assets/icon.svg create mode 
100644 api/core/tools/provider/builtin/dingtalk/dingtalk.py create mode 100644 api/core/tools/provider/builtin/dingtalk/dingtalk.yaml create mode 100644 api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py create mode 100644 api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.yaml diff --git a/api/core/tools/provider/_position.yaml b/api/core/tools/provider/_position.yaml index ece9dbe159..a69f37618e 100644 --- a/api/core/tools/provider/_position.yaml +++ b/api/core/tools/provider/_position.yaml @@ -24,3 +24,4 @@ - gaode - wecom - qrcode +- dingtalk diff --git a/api/core/tools/provider/builtin/dingtalk/_assets/icon.svg b/api/core/tools/provider/builtin/dingtalk/_assets/icon.svg new file mode 100644 index 0000000000..b60653b7a5 --- /dev/null +++ b/api/core/tools/provider/builtin/dingtalk/_assets/icon.svg @@ -0,0 +1,7 @@ + + + + + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/dingtalk/dingtalk.py b/api/core/tools/provider/builtin/dingtalk/dingtalk.py new file mode 100644 index 0000000000..be1d5e099c --- /dev/null +++ b/api/core/tools/provider/builtin/dingtalk/dingtalk.py @@ -0,0 +1,8 @@ +from core.tools.provider.builtin.dingtalk.tools.dingtalk_group_bot import DingTalkGroupBotTool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class DingTalkProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + DingTalkGroupBotTool() + pass diff --git a/api/core/tools/provider/builtin/dingtalk/dingtalk.yaml b/api/core/tools/provider/builtin/dingtalk/dingtalk.yaml new file mode 100644 index 0000000000..ebe2e4fbaf --- /dev/null +++ b/api/core/tools/provider/builtin/dingtalk/dingtalk.yaml @@ -0,0 +1,13 @@ +identity: + author: Bowen Liang + name: dingtalk + label: + en_US: DingTalk + zh_Hans: 钉钉 + pt_BR: DingTalk + description: + en_US: DingTalk group robot + zh_Hans: 钉钉群机器人 + pt_BR: DingTalk group robot + icon: icon.svg +credentials_for_provider: diff --git a/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py b/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py new file mode 100644 index 0000000000..c247c3bd6b --- /dev/null +++ b/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py @@ -0,0 +1,83 @@ +import base64 +import hashlib +import hmac +import logging +import time +import urllib.parse +from typing import Any, Union + +import httpx + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class DingTalkGroupBotTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + Dingtalk custom group robot API docs: + https://open.dingtalk.com/document/orgapp/custom-robot-access + """ + content = tool_parameters.get('content') + if not content: + return self.create_text_message('Invalid parameter content') + + access_token = tool_parameters.get('access_token') + if not access_token: + return self.create_text_message('Invalid parameter access_token. ' + 'Regarding information about security details,' + 'please refer to the DingTalk docs:' + 'https://open.dingtalk.com/document/robots/customize-robot-security-settings') + + sign_secret = tool_parameters.get('sign_secret') + if not sign_secret: + return self.create_text_message('Invalid parameter sign_secret. 
' + 'Regarding information about security details,' + 'please refer to the DingTalk docs:' + 'https://open.dingtalk.com/document/robots/customize-robot-security-settings') + + msgtype = 'text' + api_url = 'https://oapi.dingtalk.com/robot/send' + headers = { + 'Content-Type': 'application/json', + } + params = { + 'access_token': access_token, + } + + self._apply_security_mechanism(params, sign_secret) + + payload = { + "msgtype": msgtype, + "text": { + "content": content, + } + } + + try: + res = httpx.post(api_url, headers=headers, params=params, json=payload) + if res.is_success: + return self.create_text_message("Text message sent successfully") + else: + return self.create_text_message( + f"Failed to send the text message, status code: {res.status_code}, response: {res.text}") + except Exception as e: + return self.create_text_message("Failed to send message to group chat bot. {}".format(e)) + + @staticmethod + def _apply_security_mechanism(params: dict[str, Any], sign_secret: str): + try: + timestamp = str(round(time.time() * 1000)) + secret_enc = sign_secret.encode('utf-8') + string_to_sign = f'{timestamp}\n{sign_secret}' + string_to_sign_enc = string_to_sign.encode('utf-8') + hmac_code = hmac.new(secret_enc, string_to_sign_enc, digestmod=hashlib.sha256).digest() + sign = urllib.parse.quote_plus(base64.b64encode(hmac_code)) + + params['timestamp'] = timestamp + params['sign'] = sign + except Exception: + msg = "Failed to apply security mechanism to the request." + logging.exception(msg) diff --git a/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.yaml b/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.yaml new file mode 100644 index 0000000000..dc8a90b719 --- /dev/null +++ b/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.yaml @@ -0,0 +1,52 @@ +identity: + name: dingtalk_group_bot + author: Bowen Liang + label: + en_US: Send Group Message + zh_Hans: 发送群消息 + pt_BR: Send Group Message + icon: icon.svg +description: + human: + en_US: Sending a group message on DingTalk via the webhook of group bot + zh_Hans: 通过钉钉的群机器人webhook发送群消息 + pt_BR: Sending a group message on DingTalk via the webhook of group bot + llm: A tool for sending messages to a chat group on DingTalk(钉钉) . +parameters: + - name: access_token + type: secret-input + required: true + label: + en_US: access token + zh_Hans: access token + pt_BR: access token + human_description: + en_US: access_token in the group robot webhook + zh_Hans: 群自定义机器人webhook中access_token字段的值 + pt_BR: access_token in the group robot webhook + form: form + - name: sign_secret + type: secret-input + required: true + label: + en_US: secret key for signing + zh_Hans: 加签秘钥 + pt_BR: secret key for signing + human_description: + en_US: secret key for signing + zh_Hans: 加签秘钥 + pt_BR: secret key for signing + form: form + - name: content + type: string + required: true + label: + en_US: content + zh_Hans: 消息内容 + pt_BR: content + human_description: + en_US: Content to sent to the group. + zh_Hans: 群消息文本 + pt_BR: Content to sent to the group. 
+ llm_description: Content of the message + form: llm From ce5b19d011b8c943ffd36059b94091646d01b0e6 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 12 Mar 2024 14:01:24 +0800 Subject: [PATCH 013/200] bump version to 0.5.9 (#2794) --- api/config.py | 2 +- docker/docker-compose.yaml | 6 +++--- web/package.json | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/api/config.py b/api/config.py index 7c46426b47..a978a099b9 100644 --- a/api/config.py +++ b/api/config.py @@ -90,7 +90,7 @@ class Config: # ------------------------ # General Configurations. # ------------------------ - self.CURRENT_VERSION = "0.5.8" + self.CURRENT_VERSION = "0.5.9" self.COMMIT_SHA = get_env('COMMIT_SHA') self.EDITION = "SELF_HOSTED" self.DEPLOY_ENV = get_env('DEPLOY_ENV') diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index dfa01b6cef..d627bb3848 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -2,7 +2,7 @@ version: '3.1' services: # API service api: - image: langgenius/dify-api:0.5.8 + image: langgenius/dify-api:0.5.9 restart: always environment: # Startup mode, 'api' starts the API server. @@ -135,7 +135,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.5.8 + image: langgenius/dify-api:0.5.9 restart: always environment: # Startup mode, 'worker' starts the Celery worker for processing the queue. @@ -206,7 +206,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.5.8 + image: langgenius/dify-web:0.5.9 restart: always environment: EDITION: SELF_HOSTED diff --git a/web/package.json b/web/package.json index f160f29cc1..1ef344645a 100644 --- a/web/package.json +++ b/web/package.json @@ -1,6 +1,6 @@ { "name": "dify-web", - "version": "0.5.8", + "version": "0.5.9", "private": true, "scripts": { "dev": "next dev", From f734cca33710365b02c3fa31af156fbb294c4349 Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Tue, 12 Mar 2024 14:45:48 +0800 Subject: [PATCH 014/200] enhance: add stable diffusion user guide (#2795) --- .../provider/builtin/stablediffusion/stablediffusion.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.yaml b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.yaml index 7b918cb854..e1161da5bb 100644 --- a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.yaml +++ b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.yaml @@ -33,3 +33,8 @@ credentials_for_provider: en_US: Please input your model zh_Hans: 请输入你的模型名称 pt_BR: Please input your model + help: + en_US: The model name of the StableDiffusion server + zh_Hans: StableDiffusion服务器的模型名称 + pt_BR: The model name of the StableDiffusion server + url: https://docs.dify.ai/tutorials/tool-configuration/stable-diffusion From 6bd273031714311003d4859534c909412f02d835 Mon Sep 17 00:00:00 2001 From: crazywoola <100913391+crazywoola@users.noreply.github.com> Date: Tue, 12 Mar 2024 16:27:55 +0800 Subject: [PATCH 015/200] Fix/2770 suggestions for next steps (#2788) --- .../output_parser/suggested_questions_after_answer.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/api/core/prompt/output_parser/suggested_questions_after_answer.py b/api/core/prompt/output_parser/suggested_questions_after_answer.py index d8bb0809cf..e37142ec91 100644 --- a/api/core/prompt/output_parser/suggested_questions_after_answer.py +++ 
b/api/core/prompt/output_parser/suggested_questions_after_answer.py @@ -4,7 +4,6 @@ from typing import Any from langchain.schema import BaseOutputParser -from core.model_runtime.errors.invoke import InvokeError from core.prompt.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT @@ -14,11 +13,11 @@ class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser): return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT def parse(self, text: str) -> Any: - json_string = text.strip() - action_match = re.search(r".*(\[\".+\"\]).*", json_string, re.DOTALL) + action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL) if action_match is not None: - json_obj = json.loads(action_match.group(1).strip(), strict=False) + json_obj = json.loads(action_match.group(0).strip()) else: - raise InvokeError("Could not parse LLM output: {text}") + json_obj= [] + print(f"Could not parse LLM output: {text}") return json_obj From 20bd49285bc714fa4eb87b54762625daf259d9f4 Mon Sep 17 00:00:00 2001 From: Weaxs <459312872@qq.com> Date: Tue, 12 Mar 2024 01:59:25 -0700 Subject: [PATCH 016/200] excel: get keys from every sheet (#2796) --- api/core/rag/extractor/excel_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index 532391048b..30989be880 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -29,10 +29,10 @@ class ExcelExtractor(BaseExtractor): def extract(self) -> list[Document]: """Load from file path.""" data = [] - keys = [] wb = load_workbook(filename=self._file_path, read_only=True) # loop over all sheets for sheet in wb: + keys = [] if 'A1:A1' == sheet.calculate_dimension(): sheet.reset_dimensions() for row in sheet.iter_rows(values_only=True): From 620360d41a60bdf50690cf658f3bf59a801c9007 Mon Sep 17 00:00:00 2001 From: Joshua <138381132+joshua20231026@users.noreply.github.com> Date: Tue, 12 Mar 2024 17:02:46 +0800 Subject: [PATCH 017/200] Update README.md (#2799) --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 7e2740b10e..a0ea4ab905 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,9 @@ You can try out [Dify.AI Cloud](https://dify.ai) now. It provides all the capabilities of the self-deployed version, and includes 200 free requests to OpenAI GPT-3.5. +### Looking to purchase via AWS? +Check out [Dify Premium on AWS](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one-click. + ## Dify vs. LangChain vs. Assistants API | Feature | Dify.AI | Assistants API | LangChain | From 6854a3fd26798f17b8ed3938590f3078d7fd6dd4 Mon Sep 17 00:00:00 2001 From: Joshua <138381132+joshua20231026@users.noreply.github.com> Date: Tue, 12 Mar 2024 18:14:07 +0800 Subject: [PATCH 018/200] Update README.md (#2800) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a0ea4ab905..557ac86770 100644 --- a/README.md +++ b/README.md @@ -22,8 +22,8 @@
 
-
-  Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs
+
+  📌 Check out Dify Premium on AWS and deploy it to your own AWS VPC with one-click.
 
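Revisiting the OpenAI-compatible function-calling support added in patch 009 (#2783) above: the new `function_calling_type` credential selects between OpenAI's legacy `functions`/`function_call` request fields and the newer `tools`/`tool_choice` fields. A sketch of the two request shapes, following OpenAI's chat-completions conventions; the model name and weather tool here are placeholder values, not anything Dify ships:

```python
# The two OpenAI-compatible request shapes the `function_calling_type`
# credential in patch 009 switches between (illustrative payloads only).
tool = {
    "name": "get_weather",
    "description": "Get the weather for a city",
    "parameters": {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
}

legacy_function_call_payload = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "Weather in Paris?"}],
    "functions": [tool],  # function_calling_type == 'function_call'
}

newer_tool_call_payload = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "Weather in Paris?"}],
    "tools": [{"type": "function", "function": tool}],  # 'tool_call'
    "tool_choice": "auto",
}
```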
From 870ca713df4fb7b83700205a25e1fe80e8da82f1 Mon Sep 17 00:00:00 2001 From: Yulong Wang Date: Tue, 12 Mar 2024 22:06:54 +0800 Subject: [PATCH 019/200] Refactor Markdown component to include paragraph after image (#2798) --- web/app/components/base/markdown.tsx | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/web/app/components/base/markdown.tsx b/web/app/components/base/markdown.tsx index fb0cf5bb44..77ddcabed9 100644 --- a/web/app/components/base/markdown.tsx +++ b/web/app/components/base/markdown.tsx @@ -162,14 +162,17 @@ export function Markdown(props: { content: string; className?: string }) { const image = node.children[0] return ( - // eslint-disable-next-line @next/next/no-img-element - {image.properties.alt} + <> + {/* eslint-disable-next-line @next/next/no-img-element */} + {image.properties.alt} +
+          {paragraph.children.slice(1)}
+      )
+    }
 
     return (
       {paragraph.children}
     )
 
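Looking back at the DingTalk group-bot tool added in patch 012 (#2693) above: its "security mechanism" is DingTalk's documented signing scheme, an HMAC-SHA256 over `"{timestamp}\n{secret}"` that is base64- and URL-encoded and appended to the webhook URL. A standalone sketch mirroring the patch's own `_apply_security_mechanism`; the `signed_webhook_url` helper name is hypothetical:

```python
# Standalone sketch of the DingTalk webhook signing used in patch 012:
# HMAC-SHA256 of "timestamp\nsecret", base64-encoded then URL-encoded.
import base64
import hashlib
import hmac
import time
import urllib.parse

def signed_webhook_url(access_token: str, sign_secret: str) -> str:
    timestamp = str(round(time.time() * 1000))
    string_to_sign = f"{timestamp}\n{sign_secret}".encode("utf-8")
    digest = hmac.new(sign_secret.encode("utf-8"), string_to_sign,
                      digestmod=hashlib.sha256).digest()
    sign = urllib.parse.quote_plus(base64.b64encode(digest))
    return (f"https://oapi.dingtalk.com/robot/send"
            f"?access_token={access_token}&timestamp={timestamp}&sign={sign}")
```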
From e5e97c0a0a9dddd82bf94365b640de5f8f91b35f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=86=E8=90=8C=E9=97=B7=E6=B2=B9=E7=93=B6?= <253605712@qq.com> Date: Tue, 12 Mar 2024 22:07:06 +0800 Subject: [PATCH 020/200] fix:change azure openai api_version default value to 2024-02-15-preview (#2797) --- .../model_runtime/model_providers/azure_openai/_constant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py index 9779fa71a0..559442274e 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py @@ -14,7 +14,7 @@ from core.model_runtime.entities.model_entities import ( PriceConfig, ) -AZURE_OPENAI_API_VERSION = '2023-12-01-preview' +AZURE_OPENAI_API_VERSION = '2024-02-15-preview' def _get_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule: From 8f9125b08a7541280e1cf066d8c4ff6c4e69161f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=86=E8=90=8C=E9=97=B7=E6=B2=B9=E7=93=B6?= <253605712@qq.com> Date: Wed, 13 Mar 2024 13:00:46 +0800 Subject: [PATCH 021/200] fix:typo (#2808) --- api/core/model_runtime/entities/model_entities.py | 2 +- api/core/model_runtime/model_providers/__base/tts_model.py | 4 ++-- .../model_runtime/model_providers/azure_openai/_constant.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 60cb655c98..7dfd811b4f 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -133,7 +133,7 @@ class ModelPropertyKey(Enum): DEFAULT_VOICE = "default_voice" VOICES = "voices" WORD_LIMIT = "word_limit" - AUDOI_TYPE = "audio_type" + AUDIO_TYPE = "audio_type" MAX_WORKERS = "max_workers" diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py index 722d80c91e..22e546aad7 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/core/model_runtime/model_providers/__base/tts_model.py @@ -94,8 +94,8 @@ class TTSModel(AIModel): """ model_schema = self.get_model_schema(model, credentials) - if model_schema and ModelPropertyKey.AUDOI_TYPE in model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.AUDOI_TYPE] + if model_schema and ModelPropertyKey.AUDIO_TYPE in model_schema.model_properties: + return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE] def _get_model_word_limit(self, model: str, credentials: dict) -> int: """ diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py index 559442274e..dda0258812 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py @@ -628,7 +628,7 @@ TTS_BASE_MODELS = [ }, ], ModelPropertyKey.WORD_LIMIT: 120, - ModelPropertyKey.AUDOI_TYPE: 'mp3', + ModelPropertyKey.AUDIO_TYPE: 'mp3', ModelPropertyKey.MAX_WORKERS: 5 }, pricing=PriceConfig( @@ -682,7 +682,7 @@ TTS_BASE_MODELS = [ }, ], ModelPropertyKey.WORD_LIMIT: 120, - ModelPropertyKey.AUDOI_TYPE: 'mp3', + ModelPropertyKey.AUDIO_TYPE: 'mp3', ModelPropertyKey.MAX_WORKERS: 5 }, pricing=PriceConfig( From 510f8ede10c05764b1b9c1f56318b86e0cf614d3 Mon Sep 17 00:00:00 
2001
From: Mark Sun <50621638+WeepsDanky@users.noreply.github.com>
Date: Wed, 13 Mar 2024 06:10:47 +0000
Subject: [PATCH 022/200] Improve automatic prompt generation (#2805)

---
 api/core/prompt/prompts.py  | 38 ++++++++++++++++++++----------------
 web/i18n/en-US/app-debug.ts |  2 +-
 2 files changed, 23 insertions(+), 17 deletions(-)

diff --git a/api/core/prompt/prompts.py b/api/core/prompt/prompts.py
index dc8badbb98..72d8df7055 100644
--- a/api/core/prompt/prompts.py
+++ b/api/core/prompt/prompts.py
@@ -1,4 +1,4 @@
-# Written by YORKI MINAKO🤡
+# Written by YORKI MINAKO🤡, Edited by Xiaoyi
 CONVERSATION_TITLE_PROMPT = """You need to decompose the user's input into "subject" and "intention" in order to accurately figure out what the user's input language actually is.
 Notice: the language type user use could be diverse, which can be English, Chinese, Español, Arabic, Japanese, French, and etc.
 MAKE SURE your output is the SAME language as the user's input!
@@ -86,6 +86,21 @@ otherwise, it cannot exist as a variable in the variables.
 If you believe revising the original input will result in a better response from the language model, you may \
 suggest revisions.
 
+<>
+Integrate the intended audience in the prompt, e.g. the audience is an expert in the field.
+Break down complex tasks into a sequence of simpler prompts in an interactive conversation.
+Implement example-driven prompting (use few-shot prompting).
+When formatting your prompt, start with the instruction, followed by an example if relevant. \
+Subsequently present your content. Use one or more line breaks to separate instructions, examples, questions, context, and input data.
+Incorporate the following phrases: “Your task is” and “You MUST”.
+Incorporate the following phrases: “You will be penalized”.
+Use leading words like “think step by step”.
+Add to your prompt the following phrase: “Ensure that your answer is unbiased and does not rely on stereotypes”.
+Assign a role to the large language models.
+Use delimiters.
+To write an essay/text/paragraph/article or any type of text that should be detailed: “Write a detailed [essay/text/paragraph] for me on [topic] by adding all the information necessary”.
+Clearly state the requirements that the model must follow in order to produce content, in the form of keywords, regulations, hints, or instructions.
+
 << FORMATTING >>
 Return a markdown code snippet with a JSON object formatted to look like, \
 no any other string out of markdown code snippet:
@@ -102,27 +117,18 @@ and fill in variables, with a welcome sentence, and keep TLDR.
 
 [EXAMPLE A]
 ```json
 {
-    "prompt": "Write a letter about love",
-    "variables": [],
-    "opening_statement": "Hi! I'm your love letter writer AI."
+    "prompt": "I need your help to translate the following {{Input_language}} paper paragraph into {{Target_language}}, in a style similar to a popular science magazine in {{Target_language}}. #### Rules: Ensure accurate conveyance of the original text's facts and context during translation. Maintain the original paragraph format and retain technical terms and company abbreviations.",
+    "variables": ["Input_language", "Target_language"],
+    "opening_statement": "Hi. I am your translation assistant. I can help you with any translation and ensure accurate conveyance of information."
 }
 ```
 
 [EXAMPLE B]
 ```json
 {
-    "prompt": "Translate from {{lanA}} to {{lanB}}",
-    "variables": ["lanA", "lanB"],
-    "opening_statement": "Welcome to use translate app"
-}
-```
-
-[EXAMPLE C]
-```json
-{
-    "prompt": "Write a story about {{topic}}",
-    "variables": ["topic"],
-    "opening_statement": "I'm your story writer"
+    "prompt": "Your task is to review the provided meeting notes {{meeting_notes}} and create a concise summary that captures the essential information, focusing on key takeaways and action items assigned to specific individuals or departments during the meeting. Use clear and professional language, and organize the summary in a logical manner using appropriate formatting such as headings, subheadings, and bullet points. Ensure that the summary is easy to understand and provides a comprehensive but succinct overview of the meeting's content, with a particular focus on clearly indicating who is responsible for each action item.",
+    "variables": ["meeting_notes"],
+    "opening_statement": "Hi! I'm your meeting notes summarizer AI. I can help you with any meeting notes and ensure accurate conveyance of information."
 }
 ```

diff --git a/web/i18n/en-US/app-debug.ts b/web/i18n/en-US/app-debug.ts
index 82f9ae7926..51db00946d 100644
--- a/web/i18n/en-US/app-debug.ts
+++ b/web/i18n/en-US/app-debug.ts
@@ -205,7 +205,7 @@ const translation = {
     intendedAudience: 'Who is the intended audience?',
     intendedAudiencePlaceHolder: 'e.g. Student',
     solveProblem: 'What problems do they hope AI can solve for them?',
-    solveProblemPlaceHolder: 'e.g. Assessing academic performance',
+    solveProblemPlaceHolder: 'e.g. Extract insights and summarize information from long reports and articles',
     generate: 'Generate',
     audiencesRequired: 'Audiences required',
     problemRequired: 'Problem required',

From fd7aade26bd9724bf32971f8bd828ee71941fb64 Mon Sep 17 00:00:00 2001
From: "Charlie.Wei"
Date: Wed, 13 Mar 2024 15:38:10 +0800
Subject: [PATCH 023/200] Fix tts api err (#2809)

Co-authored-by: luowei
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
---
 api/controllers/service_api/app/audio.py | 8 ++++----
 api/models/model.py                      | 1 +
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py
index 60ca2171d5..f6cad501f0 100644
--- a/api/controllers/service_api/app/audio.py
+++ b/api/controllers/service_api/app/audio.py
@@ -44,7 +44,7 @@ class AudioApi(Resource):
         response = AudioService.transcript_asr(
             tenant_id=app_model.tenant_id,
             file=file,
-            end_user=end_user
+            end_user=end_user.get_id()
         )
 
         return response
@@ -75,7 +75,7 @@ class AudioApi(Resource):
 
 
 class TextApi(Resource):
-    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
     def post(self, app_model: App, end_user: EndUser):
         parser = reqparse.RequestParser()
         parser.add_argument('text', type=str, required=True, nullable=False, location='json')
@@ -86,8 +86,8 @@ class TextApi(Resource):
         response = AudioService.transcript_tts(
             tenant_id=app_model.tenant_id,
             text=args['text'],
-            end_user=end_user,
-            voice=args['voice'] if args['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
+            end_user=end_user.get_id(),
+            voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
             streaming=args['streaming']
         )

diff --git a/api/models/model.py b/api/models/model.py
index d642d9a397..8776f89673 
100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -726,6 +726,7 @@ class MessageFile(db.Model): created_by = db.Column(UUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + class MessageAnnotation(db.Model): __tablename__ = 'message_annotations' __table_args__ = ( From a026c5fd084b187bec4a6d2c6074c092e0458783 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ph=E1=BA=A1m=20Vi=E1=BA=BFt=20Ngh=C4=A9a?= <34742453+vigstudio@users.noreply.github.com> Date: Wed, 13 Mar 2024 14:54:47 +0700 Subject: [PATCH 024/200] feat: add Vietnamese vi-VN (#2807) --- api/constants/languages.py | 308 +++++++++++----- web/i18n/language.ts | 9 + web/i18n/vi-VN/app-annotation.ts | 87 +++++ web/i18n/vi-VN/app-api.ts | 82 +++++ web/i18n/vi-VN/app-debug.ts | 409 +++++++++++++++++++++ web/i18n/vi-VN/app-log.ts | 69 ++++ web/i18n/vi-VN/app-overview.ts | 139 +++++++ web/i18n/vi-VN/app.ts | 54 +++ web/i18n/vi-VN/billing.ts | 115 ++++++ web/i18n/vi-VN/common.ts | 505 ++++++++++++++++++++++++++ web/i18n/vi-VN/custom.ts | 30 ++ web/i18n/vi-VN/dataset-creation.ts | 130 +++++++ web/i18n/vi-VN/dataset-documents.ts | 349 ++++++++++++++++++ web/i18n/vi-VN/dataset-hit-testing.ts | 28 ++ web/i18n/vi-VN/dataset-settings.ts | 33 ++ web/i18n/vi-VN/dataset.ts | 47 +++ web/i18n/vi-VN/explore.ts | 41 +++ web/i18n/vi-VN/layout.ts | 4 + web/i18n/vi-VN/login.ts | 59 +++ web/i18n/vi-VN/register.ts | 4 + web/i18n/vi-VN/share-app.ts | 74 ++++ web/i18n/vi-VN/tools.ts | 115 ++++++ 22 files changed, 2606 insertions(+), 85 deletions(-) create mode 100644 web/i18n/vi-VN/app-annotation.ts create mode 100644 web/i18n/vi-VN/app-api.ts create mode 100644 web/i18n/vi-VN/app-debug.ts create mode 100644 web/i18n/vi-VN/app-log.ts create mode 100644 web/i18n/vi-VN/app-overview.ts create mode 100644 web/i18n/vi-VN/app.ts create mode 100644 web/i18n/vi-VN/billing.ts create mode 100644 web/i18n/vi-VN/common.ts create mode 100644 web/i18n/vi-VN/custom.ts create mode 100644 web/i18n/vi-VN/dataset-creation.ts create mode 100644 web/i18n/vi-VN/dataset-documents.ts create mode 100644 web/i18n/vi-VN/dataset-hit-testing.ts create mode 100644 web/i18n/vi-VN/dataset-settings.ts create mode 100644 web/i18n/vi-VN/dataset.ts create mode 100644 web/i18n/vi-VN/explore.ts create mode 100644 web/i18n/vi-VN/layout.ts create mode 100644 web/i18n/vi-VN/login.ts create mode 100644 web/i18n/vi-VN/register.ts create mode 100644 web/i18n/vi-VN/share-app.ts create mode 100644 web/i18n/vi-VN/tools.ts diff --git a/api/constants/languages.py b/api/constants/languages.py index 284f3d8758..0ae69d77d2 100644 --- a/api/constants/languages.py +++ b/api/constants/languages.py @@ -2,7 +2,7 @@ import json from models.model import AppModelConfig -languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA'] +languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA', 'vi-VN'] language_timezone_mapping = { 'en-US': 'America/New_York', @@ -16,6 +16,7 @@ language_timezone_mapping = { 'ru-RU': 'Europe/Moscow', 'it-IT': 'Europe/Rome', 'uk-UA': 'Europe/Kyiv', + 'vi-VN': 'Asia/Ho_Chi_Minh', } @@ -79,6 +80,16 @@ user_input_form_template = { } } ], + "vi-VN": [ + { + "paragraph": { + "label": "Nội dung truy vấn", + "variable": "default_input", + "required": False, + "default": "" + } + } + ], } demo_model_templates = { @@ -208,7 +219,6 @@ demo_model_templates = { ) } ], - 'zh-Hans': [ { 'name': '翻译助手', @@ -335,91 +345,92 @@ 
demo_model_templates = { ) } ], - 'uk-UA': [{ - "name": "Помічник перекладу", - "icon": "", - "icon_background": "", - "description": "Багатомовний перекладач, який надає можливості перекладу різними мовами, перекладаючи введені користувачем дані на потрібну мову.", - "mode": "completion", - "model_config": AppModelConfig( - provider="openai", - model_id="gpt-3.5-turbo-instruct", - configs={ - "prompt_template": "Будь ласка, перекладіть наступний текст на {{target_language}}:\n", - "prompt_variables": [ - { - "key": "target_language", - "name": "Цільова мова", - "description": "Мова, на яку ви хочете перекласти.", - "type": "select", - "default": "Ukrainian", - "options": [ - "Chinese", - "English", - "Japanese", - "French", - "Russian", - "German", - "Spanish", - "Korean", - "Italian", - ], + 'uk-UA': [ + { + "name": "Помічник перекладу", + "icon": "", + "icon_background": "", + "description": "Багатомовний перекладач, який надає можливості перекладу різними мовами, перекладаючи введені користувачем дані на потрібну мову.", + "mode": "completion", + "model_config": AppModelConfig( + provider="openai", + model_id="gpt-3.5-turbo-instruct", + configs={ + "prompt_template": "Будь ласка, перекладіть наступний текст на {{target_language}}:\n", + "prompt_variables": [ + { + "key": "target_language", + "name": "Цільова мова", + "description": "Мова, на яку ви хочете перекласти.", + "type": "select", + "default": "Ukrainian", + "options": [ + "Chinese", + "English", + "Japanese", + "French", + "Russian", + "German", + "Spanish", + "Korean", + "Italian", + ], + }, + ], + "completion_params": { + "max_token": 1000, + "temperature": 0, + "top_p": 0, + "presence_penalty": 0.1, + "frequency_penalty": 0.1, }, - ], - "completion_params": { - "max_token": 1000, - "temperature": 0, - "top_p": 0, - "presence_penalty": 0.1, - "frequency_penalty": 0.1, }, - }, - opening_statement="", - suggested_questions=None, - pre_prompt="Будь ласка, перекладіть наступний текст на {{target_language}}:\n{{query}}\ntranslate:", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo-instruct", - "mode": "completion", - "completion_params": { - "max_tokens": 1000, - "temperature": 0, - "top_p": 0, - "presence_penalty": 0.1, - "frequency_penalty": 0.1, - }, - }), - user_input_form=json.dumps([ - { - "select": { - "label": "Цільова мова", - "variable": "target_language", - "description": "Мова, на яку ви хочете перекласти.", - "default": "Chinese", - "required": True, - 'options': [ - 'Chinese', - 'English', - 'Japanese', - 'French', - 'Russian', - 'German', - 'Spanish', - 'Korean', - 'Italian', - ] + opening_statement="", + suggested_questions=None, + pre_prompt="Будь ласка, перекладіть наступний текст на {{target_language}}:\n{{query}}\ntranslate:", + model=json.dumps({ + "provider": "openai", + "name": "gpt-3.5-turbo-instruct", + "mode": "completion", + "completion_params": { + "max_tokens": 1000, + "temperature": 0, + "top_p": 0, + "presence_penalty": 0.1, + "frequency_penalty": 0.1, + }, + }), + user_input_form=json.dumps([ + { + "select": { + "label": "Цільова мова", + "variable": "target_language", + "description": "Мова, на яку ви хочете перекласти.", + "default": "Chinese", + "required": True, + 'options': [ + 'Chinese', + 'English', + 'Japanese', + 'French', + 'Russian', + 'German', + 'Spanish', + 'Korean', + 'Italian', + ] + } + }, { + "paragraph": { + "label": "Запит", + "variable": "query", + "required": True, + "default": "" + } } - }, { - "paragraph": { - "label": "Запит", - "variable": "query", - 
"required": True, - "default": "" - } - } - ]) - ) - }, + ]) + ) + }, { "name": "AI інтерв’юер фронтенду", "icon": "", @@ -460,5 +471,132 @@ demo_model_templates = { ), } ], - + 'vi-VN': [ + { + 'name': 'Trợ lý dịch thuật', + 'icon': '', + 'icon_background': '', + 'description': 'Trình dịch đa ngôn ngữ cung cấp khả năng dịch bằng nhiều ngôn ngữ, dịch thông tin đầu vào của người dùng sang ngôn ngữ họ cần.', + 'mode': 'completion', + 'model_config': AppModelConfig( + provider='openai', + model_id='gpt-3.5-turbo-instruct', + configs={ + 'prompt_template': "Hãy dịch đoạn văn bản sau sang ngôn ngữ {{target_language}}:\n", + 'prompt_variables': [ + { + "key": "target_language", + "name": "Ngôn ngữ đích", + "description": "Ngôn ngữ bạn muốn dịch sang.", + "type": "select", + "default": "Vietnamese", + 'options': [ + 'Chinese', + 'English', + 'Japanese', + 'French', + 'Russian', + 'German', + 'Spanish', + 'Korean', + 'Italian', + 'Vietnamese', + ] + } + ], + 'completion_params': { + 'max_token': 1000, + 'temperature': 0, + 'top_p': 0, + 'presence_penalty': 0.1, + 'frequency_penalty': 0.1, + } + }, + opening_statement='', + suggested_questions=None, + pre_prompt="Hãy dịch đoạn văn bản sau sang {{target_language}}:\n{{query}}\ndịch:", + model=json.dumps({ + "provider": "openai", + "name": "gpt-3.5-turbo-instruct", + "mode": "completion", + "completion_params": { + "max_tokens": 1000, + "temperature": 0, + "top_p": 0, + "presence_penalty": 0.1, + "frequency_penalty": 0.1 + } + }), + user_input_form=json.dumps([ + { + "select": { + "label": "Ngôn ngữ đích", + "variable": "target_language", + "description": "Ngôn ngữ bạn muốn dịch sang.", + "default": "Vietnamese", + "required": True, + 'options': [ + 'Chinese', + 'English', + 'Japanese', + 'French', + 'Russian', + 'German', + 'Spanish', + 'Korean', + 'Italian', + 'Vietnamese', + ] + } + }, { + "paragraph": { + "label": "Query", + "variable": "query", + "required": True, + "default": "" + } + } + ]) + ) + }, + { + 'name': 'Phỏng vấn front-end AI', + 'icon': '', + 'icon_background': '', + 'description': 'Một người phỏng vấn front-end mô phỏng để kiểm tra mức độ kỹ năng phát triển front-end thông qua việc đặt câu hỏi.', + 'mode': 'chat', + 'model_config': AppModelConfig( + provider='openai', + model_id='gpt-3.5-turbo', + configs={ + 'introduction': 'Xin chào, chào mừng đến với cuộc phỏng vấn của chúng tôi. Tôi là người phỏng vấn cho công ty công nghệ này và tôi sẽ kiểm tra kỹ năng phát triển web front-end của bạn. Tiếp theo, tôi sẽ hỏi bạn một số câu hỏi kỹ thuật. Hãy trả lời chúng càng kỹ lưỡng càng tốt. ', + 'prompt_template': "Bạn sẽ đóng vai người phỏng vấn cho một công ty công nghệ, kiểm tra kỹ năng phát triển web front-end của người dùng và đặt ra 5-10 câu hỏi kỹ thuật sắc bén.\n\nXin lưu ý:\n- Mỗi lần chỉ hỏi một câu hỏi.\n - Sau khi người dùng trả lời một câu hỏi, hãy hỏi trực tiếp câu hỏi tiếp theo mà không cố gắng sửa bất kỳ lỗi nào mà thí sinh mắc phải.\n- Nếu bạn cho rằng người dùng đã không trả lời đúng cho một số câu hỏi liên tiếp, hãy hỏi ít câu hỏi hơn.\n- Sau đặt câu hỏi cuối cùng, bạn có thể hỏi câu hỏi này: Tại sao bạn lại rời bỏ công việc cuối cùng của mình? Sau khi người dùng trả lời câu hỏi này, vui lòng bày tỏ sự hiểu biết và ủng hộ của bạn.\n", + 'prompt_variables': [], + 'completion_params': { + 'max_token': 300, + 'temperature': 0.8, + 'top_p': 0.9, + 'presence_penalty': 0.1, + 'frequency_penalty': 0.1, + } + }, + opening_statement='Xin chào, chào mừng đến với cuộc phỏng vấn của chúng tôi. 
Tôi là người phỏng vấn cho công ty công nghệ này và tôi sẽ kiểm tra kỹ năng phát triển web front-end của bạn. Tiếp theo, tôi sẽ hỏi bạn một số câu hỏi kỹ thuật. Hãy trả lời chúng càng kỹ lưỡng càng tốt. ', + suggested_questions=None, + pre_prompt="Bạn sẽ đóng vai người phỏng vấn cho một công ty công nghệ, kiểm tra kỹ năng phát triển web front-end của người dùng và đặt ra 5-10 câu hỏi kỹ thuật sắc bén.\n\nXin lưu ý:\n- Mỗi lần chỉ hỏi một câu hỏi.\n - Sau khi người dùng trả lời một câu hỏi, hãy hỏi trực tiếp câu hỏi tiếp theo mà không cố gắng sửa bất kỳ lỗi nào mà thí sinh mắc phải.\n- Nếu bạn cho rằng người dùng đã không trả lời đúng cho một số câu hỏi liên tiếp, hãy hỏi ít câu hỏi hơn.\n- Sau đặt câu hỏi cuối cùng, bạn có thể hỏi câu hỏi này: Tại sao bạn lại rời bỏ công việc cuối cùng của mình? Sau khi người dùng trả lời câu hỏi này, vui lòng bày tỏ sự hiểu biết và ủng hộ của bạn.\n", + model=json.dumps({ + "provider": "openai", + "name": "gpt-3.5-turbo", + "mode": "chat", + "completion_params": { + "max_tokens": 300, + "temperature": 0.8, + "top_p": 0.9, + "presence_penalty": 0.1, + "frequency_penalty": 0.1 + } + }), + user_input_form=None + ) + } + ], } diff --git a/web/i18n/language.ts b/web/i18n/language.ts index 9b7f5a27f0..c72717dc84 100644 --- a/web/i18n/language.ts +++ b/web/i18n/language.ts @@ -16,6 +16,7 @@ export type I18nText = { 'ru-RU': string 'it-IT': string 'uk-UA': string + 'vi-VN': string } export const languages = [ @@ -97,6 +98,12 @@ export const languages = [ example: 'Привет, Dify!', supported: true, }, + { + value: 'vi-VN', + name: 'Tiếng Việt(Việt Nam)', + example: 'Xin chào, Dify!', + supported: true, + }, ] export const LanguagesSupported = languages.filter(item => item.supported).map(item => item.value) @@ -119,6 +126,7 @@ export const NOTICE_I18N = { ja_JP: '重要なお知らせ', ko_KR: '중요 공지', uk_UA: 'Важливе повідомлення', + vi_VN: 'Thông báo quan trọng', }, desc: { en_US: 'Our system will be unavailable from 19:00 to 24:00 UTC on August 28 for an upgrade. For questions, kindly contact our support team (support@dify.ai). We value your patience.', @@ -130,6 +138,7 @@ export const NOTICE_I18N = { ja_JP: 'Our system will be unavailable from 19:00 to 24:00 UTC on August 28 for an upgrade. For questions, kindly contact our support team (support@dify.ai). We value your patience.', ko_KR: 'Our system will be unavailable from 19:00 to 24:00 UTC on August 28 for an upgrade. For questions, kindly contact our support team (support@dify.ai). We value your patience.', uk_UA: 'Наша система буде недоступна з 19:00 до 24:00 UTC 28 серпня для оновлення. Якщо у вас виникнуть запитання, будь ласка, зв’яжіться з нашою службою підтримки (support@dify.ai). Дякуємо за терпіння.', + vi_VN: 'Hệ thống của chúng tôi sẽ ngừng hoạt động từ 19:00 đến 24:00 UTC vào ngày 28 tháng 8 để nâng cấp. Nếu có thắc mắc, vui lòng liên hệ với nhóm hỗ trợ của chúng tôi (support@dify.ai). 
Chúng tôi đánh giá cao sự kiên nhẫn của bạn.',
   },
   href: '#',
 }
diff --git a/web/i18n/vi-VN/app-annotation.ts b/web/i18n/vi-VN/app-annotation.ts
new file mode 100644
index 0000000000..2b70b8608c
--- /dev/null
+++ b/web/i18n/vi-VN/app-annotation.ts
@@ -0,0 +1,87 @@
+const translation = {
+  title: 'Chú thích',
+  name: 'Chú thích cho câu Trả lời',
+  editBy: 'Câu trả lời được chỉnh sửa bởi {{author}}',
+  noData: {
+    title: 'Không có chú thích',
+    description: 'Bạn có thể chỉnh sửa chú thích trong phần gỡ lỗi ứng dụng hoặc nhập hàng loạt chú thích vào đây để có phản hồi chất lượng cao.',
+  },
+  table: {
+    header: {
+      question: 'câu hỏi',
+      answer: 'câu trả lời',
+      createdAt: 'tạo lúc',
+      hits: 'lượt truy cập',
+      actions: 'hành động',
+      addAnnotation: 'Thêm chú thích',
+      bulkImport: 'Nhập hàng loạt',
+      bulkExport: 'Xuất hàng loạt',
+      clearAll: 'Xóa tất cả chú thích',
+    },
+  },
+  editModal: {
+    title: 'Chỉnh sửa chú thích Trả lời',
+    queryName: 'Truy vấn người dùng',
+    answerName: 'Câu chuyện của BOT',
+    yourAnswer: 'Câu trả lời của bạn',
+    answerPlaceholder: 'Nhập câu trả lời của bạn vào đây',
+    yourQuery: 'Truy vấn của bạn',
+    queryPlaceholder: 'Nhập truy vấn của bạn ở đây',
+    removeThisCache: 'Xóa chú thích này',
+    createdAt: 'Được tạo lúc',
+  },
+  addModal: {
+    title: 'Thêm chú thích Trả lời',
+    queryName: 'Câu hỏi',
+    answerName: 'Trả lời',
+    answerPlaceholder: 'Nhập câu trả lời vào đây',
+    queryPlaceholder: 'Nhập truy vấn ở đây',
+    createNext: 'Thêm một phản hồi có chú thích khác',
+  },
+  batchModal: {
+    title: 'Nhập hàng loạt',
+    csvUploadTitle: 'Kéo và thả tệp CSV của bạn vào đây hoặc ',
+    browse: 'duyệt',
+    tip: 'Tệp CSV phải tuân theo cấu trúc sau:',
+    question: 'câu hỏi',
+    answer: 'trả lời',
+    contentTitle: 'đoạn nội dung',
+    content: 'nội dung',
+    template: 'Tải mẫu tại đây',
+    cancel: 'Hủy',
+    run: 'Chạy hàng loạt',
+    runError: 'Chạy hàng loạt thất bại',
+    processing: 'Đang xử lý hàng loạt',
+    completed: 'Nhập hoàn tất',
+    error: 'Lỗi nhập',
+    ok: 'OK',
+  },
+  errorMessage: {
+    answerRequired: 'Câu trả lời là bắt buộc',
+    queryRequired: 'Câu hỏi là bắt buộc',
+  },
+  viewModal: {
+    annotatedResponse: 'Chú thích Trả lời',
+    hitHistory: 'Lịch sử lượt truy cập',
+    hit: 'Truy cập',
+    hits: 'Lượt truy cập',
+    noHitHistory: 'Không có lịch sử truy cập',
+  },
+  hitHistoryTable: {
+    query: 'Truy vấn',
+    match: 'Chính xác',
+    response: 'Phản hồi',
+    source: 'Nguồn',
+    score: 'Điểm',
+    time: 'Thời gian',
+  },
+  initSetup: {
+    title: 'Chú thích Trả lời Thiết lập ban đầu',
+    configTitle: 'Thiết lập trả lời chú thích',
+    confirmBtn: 'Lưu & Kích hoạt',
+    configConfirmBtn: 'Lưu',
+  },
+  embeddingModelSwitchTip: 'Mô hình vector hóa văn bản chú thích, mô hình chuyển đổi sẽ được nhúng lại, dẫn đến phát sinh thêm chi phí.',
+}
+
+export default translation
diff --git a/web/i18n/vi-VN/app-api.ts b/web/i18n/vi-VN/app-api.ts
new file mode 100644
index 0000000000..67c5a8bfa7
--- /dev/null
+++ b/web/i18n/vi-VN/app-api.ts
@@ -0,0 +1,82 @@
+const translation = {
+  apiServer: 'API Server',
+  apiKey: 'API Key',
+  status: 'Trạng thái',
+  disabled: 'Tắt',
+  ok: 'Đang hoạt động',
+  copy: 'Sao chép',
+  copied: 'Đã chép',
+  play: 'Chạy',
+  pause: 'Dừng',
+  playing: 'Đang chạy',
+  merMaind: {
+    rerender: 'Redo Rerender',
+  },
+  never: 'Không bao giờ',
+  apiKeyModal: {
+    apiSecretKey: 'API khoá bí mật',
+    apiSecretKeyTips: 'Để ngăn chặn việc lạm dụng API, hãy bảo vệ Khóa API của bạn. Tránh sử dụng nó dưới dạng văn bản thuần túy trong mã giao diện người dùng. 
:)', + createNewSecretKey: 'Tạo khóa bí mật mới', + secretKey: 'Khóa bí mật', + created: 'CREATED', + lastUsed: 'LAST USED', + generateTips: 'Giữ chìa khóa này ở nơi an toàn và dễ tiếp cận.', + }, + actionMsg: { + deleteConfirmTitle: 'Xóa khóa bí mật này?', + deleteConfirmTips: 'Hành động này không thể được hoàn tác.', + ok: 'OK', + }, + completionMode: { + title: 'Completion App API', + info: 'Đối với việc tạo văn bản chất lượng cao, như bài viết, tóm tắt và dịch thuật, hãy sử dụng API hoàn thành tin nhắn với đầu vào người dùng. Việc tạo văn bản dựa trên các thông số mô hình và mẫu đề xuất được thiết lập trong Dify Prompt Engineering.', + createCompletionApi: 'Tạo tin nhắn hoàn thành', + createCompletionApiTip: 'Tạo một tin nhắn hoàn thành để hỗ trợ chế độ câu hỏi và trả lời.', + inputsTips: '(Tùy chọn) Cung cấp các trường đầu vào người dùng dưới dạng cặp khóa-giá trị, tương ứng với các biến trong Prompt Eng. Khóa là tên biến, Giá trị là giá trị tham số. Nếu loại trường là Lựa chọn, Giá trị đã gửi phải là một trong các lựa chọn đã thiết lập trước.', + queryTips: 'Nội dung văn bản đầu vào của người dùng.', + blocking: 'Loại chặn, đợi để thực hiện hoàn tất và trả kết quả. (Yêu cầu có thể bị gián đoạn nếu quá trình kéo dài)', + streaming: 'trả về dữ liệu theo luồng. Thực hiện trả dữ liệu theo luồng dựa trên SSE (Sự kiện được gửi từ máy chủ).', + messageFeedbackApi: 'Phản hồi tin nhắn (thích)', + messageFeedbackApiTip: 'Đánh giá các tin nhắn nhận được thay mặt cho người dùng cuối với các lựa chọn thích hoặc không thích. Dữ liệu này hiển thị trên trang Nhật ký & Chú thích và được sử dụng cho việc điều chỉnh mô hình trong tương lai.', + messageIDTip: 'ID tin nhắn', + ratingTip: 'thích hoặc không thích, null là hủy bỏ', + parametersApi: 'Thu thập thông tin tham số ứng dụng', + parametersApiTip: 'Truy xuất các tham số Đầu vào được cấu hình, bao gồm tên biến, tên trường, loại và giá trị mặc định. Thường được sử dụng để hiển thị các trường này trong một biểu mẫu hoặc điền vào các giá trị mặc định sau khi máy khách tải.', + }, + chatMode: { + title: 'Chat App API', + info: 'Đối với ứng dụng trò chuyện linh hoạt sử dụng định dạng Câu hỏi và Trả lời, gọi API tin nhắn trò chuyện để bắt đầu cuộc trò chuyện. Duy trì cuộc trò chuyện liên tục bằng cách chuyển conversation_id đã trả về. Các tham số phản hồi và mẫu phụ thuộc vào các cài đặt của Dify Prompt Eng.', + createChatApi: 'Tạo tin nhắn trò chuyện', + createChatApiTip: 'Tạo một tin nhắn trò chuyện mới hoặc tiếp tục một cuộc trò chuyện đang tồn tại.', + inputsTips: '(Tùy chọn) Cung cấp các trường đầu vào người dùng dưới dạng cặp khóa-giá trị, tương ứng với các biến trong Prompt Eng. Khóa là tên biến, Giá trị là giá trị tham số. Nếu loại trường là Lựa chọn, Giá trị đã gửi phải là một trong các lựa chọn đã thiết lập trước.', + queryTips: 'Nội dung câu hỏi của người dùng', + blocking: 'Loại chặn, đợi để thực hiện hoàn tất và trả kết quả. (Yêu cầu có thể bị gián đoạn nếu quá trình kéo dài)', + streaming: 'trả về dữ liệu theo luồng. Thực hiện trả dữ liệu theo luồng dựa trên SSE (Sự kiện được gửi từ máy chủ).', + conversationIdTip: '(Tùy chọn) ID cuộc trò chuyện: để trống cho cuộc trò chuyện lần đầu; chuyển conversation_id từ ngữ cảnh để tiếp tục cuộc trò chuyện.', + messageFeedbackApi: 'Phản hồi của người dùng cuối về tin nhắn, như', + messageFeedbackApiTip: 'Đánh giá các tin nhắn nhận được thay mặt cho người dùng cuối với các lựa chọn thích hoặc không thích. 
Dữ liệu này hiển thị trên trang Nhật ký & Chú thích và được sử dụng cho việc điều chỉnh mô hình trong tương lai.', + messageIDTip: 'ID tin nhắn', + ratingTip: 'thích hoặc không thích, null là hủy bỏ', + chatMsgHistoryApi: 'Lấy lịch sử tin nhắn trò chuyện', + chatMsgHistoryApiTip: 'Trang đầu tiên trả về `limit` tin nhắn mới nhất, được sắp xếp theo thứ tự ngược lại.', + chatMsgHistoryConversationIdTip: 'ID Cuộc trò chuyện', + chatMsgHistoryFirstId: 'ID của bản ghi trò chuyện đầu tiên trên trang hiện tại. Giá trị mặc định là không có.', + chatMsgHistoryLimit: 'Bao nhiêu cuộc trò chuyện được trả lại trong một yêu cầu', + conversationsListApi: 'Lấy danh sách cuộc trò chuyện', + conversationsListApiTip: 'Lấy danh sách phiên của người dùng hiện tại. Theo mặc định, trả về 20 phiên cuối cùng.', + conversationsListFirstIdTip: 'ID của bản ghi cuối cùng trên trang hiện tại, mặc định không có.', + conversationsListLimitTip: 'Bao nhiêu cuộc trò chuyện được trả lại trong một yêu cầu', + conversationRenamingApi: 'Đổi tên cuộc trò chuyện', + conversationRenamingApiTip: 'Đổi tên cuộc trò chuyện; tên sẽ được hiển thị trong giao diện nhiều phiên.', + conversationRenamingNameTip: 'Tên mới', + parametersApi: 'Thu thập thông tin tham số ứng dụng', + parametersApiTip: 'Truy xuất các tham số Đầu vào được cấu hình, bao gồm tên biến, tên trường, loại và giá trị mặc định. Thường được sử dụng để hiển thị các trường này trong một biểu mẫu hoặc điền vào các giá trị mặc định sau khi máy khách tải.', + }, + develop: { + requestBody: 'Nội dung yêu cầu', + pathParams: 'Thông số đường dẫn', + query: 'Truy vấn', + }, +} + +export default translation diff --git a/web/i18n/vi-VN/app-debug.ts b/web/i18n/vi-VN/app-debug.ts new file mode 100644 index 0000000000..7a65edc8b4 --- /dev/null +++ b/web/i18n/vi-VN/app-debug.ts @@ -0,0 +1,409 @@ +const translation = { + pageTitle: { + line1: 'PROMPT', + line2: 'Engineering', + }, + orchestrate: 'Orchestrate', + promptMode: { + simple: 'Chuyển sang Chế độ Chuyên gia để chỉnh sửa toàn bộ PROMPT', + advanced: 'Chế độ Chuyên gia', + switchBack: 'Chuyển về', + advancedWarning: { + title: 'Bạn đã chuyển sang Chế độ Chuyên gia, và một khi bạn sửa đổi PROMPT, bạn KHÔNG THỂ quay lại chế độ cơ bản.', + description: 'Trong Chế độ Chuyên gia, bạn có thể chỉnh sửa toàn bộ PROMPT.', + learnMore: 'Tìm hiểu thêm', + ok: 'OK', + }, + operation: { + addMessage: 'Thêm Tin nhắn', + }, + contextMissing: 'Thiếu thành phần Ngữ cảnh, hiệu quả của prompt có thể không tốt.', + }, + operation: { + applyConfig: 'Xuất bản', + resetConfig: 'Đặt lại', + debugConfig: 'Gỡ lỗi', + addFeature: 'Thêm Tính năng', + automatic: 'Tự động', + stopResponding: 'Dừng phản hồi', + agree: 'thích', + disagree: 'không thích', + cancelAgree: 'Hủy thích', + cancelDisagree: 'Hủy không thích', + userAction: 'Người dùng ', + }, + notSetAPIKey: { + title: 'Khóa nhà cung cấp LLM chưa được đặt', + trailFinished: 'Kết thúc dấu vết', + description: 'Khóa nhà cung cấp LLM chưa được đặt, và cần được đặt trước khi gỡ lỗi.', + settingBtn: 'Đi đến cài đặt', + }, + trailUseGPT4Info: { + title: 'Hiện không hỗ trợ gpt-4', + description: 'Sử dụng gpt-4, vui lòng đặt API Key.', + }, + feature: { + groupChat: { + title: 'Nâng cao trò chuyện', + description: 'Thêm cài đặt trước cuộc trò chuyện cho ứng dụng có thể cải thiện trải nghiệm người dùng.', + }, + groupExperience: { + title: 'Nâng cao trải nghiệm', + }, + conversationOpener: { + title: 'Khởi động cuộc trò chuyện', + description: 'Trong một ứng dụng trò chuyện, câu nói đầu tiên mà AI tự động nói với 
người dùng thường được sử dụng như một lời chào.', + }, + suggestedQuestionsAfterAnswer: { + title: 'Theo dõi', + description: 'Thiết lập đề xuất câu hỏi tiếp theo có thể mang lại trò chuyện tốt hơn cho người dùng.', + resDes: '3 đề xuất cho câu hỏi tiếp theo của người dùng.', + tryToAsk: 'Thử hỏi', + }, + moreLikeThis: { + title: 'Nhiều hơn như vậy', + description: 'Tạo nhiều văn bản cùng một lúc, và sau đó chỉnh sửa và tiếp tục tạo ra.', + generateNumTip: 'Số lượng mỗi lần tạo ra', + tip: 'Sử dụng tính năng này sẽ tốn thêm token.', + }, + speechToText: { + title: 'Chuyển đổi Giọng nói thành Văn bản', + description: 'Một khi được bật, bạn có thể sử dụng đầu vào giọng nói.', + resDes: 'Đầu vào Giọng nói đã được bật', + }, + textToSpeech: { + title: 'Chuyển đổi Văn bản thành Giọng nói', + description: 'Một khi được bật, văn bản có thể được chuyển đổi thành giọng nói.', + resDes: 'Chuyển đổi Văn bản thành Âm thanh đã được bật', + }, + citation: { + title: 'Trích dẫn và Ghi chú', + description: 'Một khi được bật, hiển thị tài liệu nguồn và phần được ghi nhận của nội dung được tạo ra.', + resDes: 'Trích dẫn và Ghi chú đã được bật', + }, + annotation: { + title: 'Phản hồi Chú thích', + description: 'Bạn có thể thêm phản hồi chất lượng cao vào bộ nhớ cache để ưu tiên phù hợp với các câu hỏi của người dùng tương tự.', + resDes: 'Phản hồi Chú thích đã được bật', + scoreThreshold: { + title: 'Ngưỡng Điểm', + description: 'Được sử dụng để đặt ngưỡng tương đồng cho phản hồi chú thích.', + easyMatch: 'Tương đồng Dễ dàng', + accurateMatch: 'Tương đồng Chính xác', + }, + matchVariable: { + title: 'Biến Phù hợp', + choosePlaceholder: 'Chọn biến phù hợp', + }, + cacheManagement: 'Chú thích', + cached: 'Đã ghi chú', + remove: 'Xóa', + removeConfirm: 'Xóa chú thích này?', + add: 'Thêm chú thích', + edit: 'Chỉnh sửa chú thích', + }, + dataSet: { + title: 'Ngữ cảnh', + noData: 'Bạn có thể nhập Dữ liệu như là ngữ cảnh', + words: 'Từ', + textBlocks: 'Khối Văn bản', + selectTitle: 'Chọn Kiến thức tham khảo', + selected: 'Kiến thức đã chọn', + noDataSet: 'Không tìm thấy Kiến thức', + toCreate: 'Đi tới tạo mới', + notSupportSelectMulti: 'Hiện chỉ hỗ trợ một Kiến thức', + queryVariable: { + title: 'Biến Truy vấn', + tip: 'Biến này sẽ được sử dụng làm đầu vào truy vấn cho việc truy xuất ngữ cảnh, lấy thông tin ngữ cảnh liên quan đến đầu vào của biến này.', + choosePlaceholder: 'Chọn biến truy vấn', + noVar: 'Không có biến', + noVarTip: 'xin vui lòng tạo một biến dưới phần Biến', + unableToQueryDataSet: 'Không thể truy vấn Kiến thức', + unableToQueryDataSetTip: 'Không thể truy vấn Kiến thức thành công, vui lòng chọn một biến truy vấn ngữ cảnh trong phần ngữ cảnh.', + ok: 'OK', + contextVarNotEmpty: 'biến truy vấn ngữ cảnh không thể trống', + deleteContextVarTitle: 'Xóa biến “{{varName}}”?', + deleteContextVarTip: 'Biến này đã được thiết lập là biến truy vấn ngữ cảnh, và việc loại bỏ nó sẽ ảnh hưởng đến việc sử dụng bình thường của Kiến thức. 
Nếu bạn vẫn cần xóa nó, vui lòng chọn lại nó trong phần ngữ cảnh.', + }, + }, + tools: { + title: 'Công cụ', + tips: 'Công cụ cung cấp một phương thức gọi API chuẩn, lấy đầu vào người dùng hoặc biến làm tham số yêu cầu để truy vấn dữ liệu bên ngoài như ngữ cảnh.', + toolsInUse: '{{count}} công cụ đang được sử dụng', + modal: { + title: 'Công cụ', + toolType: { + title: 'Loại Công cụ', + placeholder: 'Vui lòng chọn loại công cụ', + }, + name: { + title: 'Tên', + placeholder: 'Vui lòng nhập tên', + }, + variableName: { + title: 'Tên Biến', + placeholder: 'Vui lòng nhập tên biến', + }, + }, + }, + conversationHistory: { + title: 'Lịch sử Cuộc trò chuyện', + description: 'Đặt tên tiền tố cho các vai trò trong cuộc trò chuyện', + tip: 'Lịch sử Cuộc trò chuyện chưa được bật, vui lòng thêm vào phần prompt ở trên.', + learnMore: 'Tìm hiểu thêm', + editModal: { + title: 'Chỉnh sửa Tên Vai trò Cuộc trò chuyện', + userPrefix: 'Tiền tố Người dùng', + assistantPrefix: 'Tiền tố Trợ lý', + }, + }, + toolbox: { + title: 'HỘP CÔNG CỤ', + }, + moderation: { + title: 'Kiểm duyệt nội dung', + description: 'Bảo vệ đầu ra của mô hình bằng cách sử dụng API kiểm duyệt hoặc duy trì một danh sách từ nhạy cảm.', + allEnabled: 'Nội dung ĐẦU VÀO/ĐẦU RA Đã Bật', + inputEnabled: 'Nội dung ĐẦU VÀO Đã Bật', + outputEnabled: 'Nội dung ĐẦU RA Đã Bật', + modal: { + title: 'Cài đặt kiểm duyệt nội dung', + provider: { + title: 'Nhà cung cấp', + openai: 'Kiểm duyệt OpenAI', + openaiTip: { + prefix: 'Kiểm duyệt OpenAI yêu cầu một khóa API OpenAI được cấu hình trong ', + suffix: '.', + }, + keywords: 'Từ khóa', + }, + keywords: { + tip: 'Mỗi dòng một từ khóa, phân tách bằng các dòng. Tối đa 100 ký tự mỗi dòng.', + placeholder: 'Mỗi dòng một từ khóa, phân tách bằng các dòng', + line: 'Dòng', + }, + content: { + input: 'Kiểm duyệt Nội dung ĐẦU VÀO', + output: 'Kiểm duyệt Nội dung ĐẦU RA', + preset: 'Câu trả lời mẫu', + placeholder: 'Nội dung câu trả lời mẫu ở đây', + condition: 'Đã bật ít nhất một kiểm duyệt nội dung ĐẦU VÀO và ĐẦU RA', + fromApi: 'Câu trả lời mẫu được trả về bởi API', + errorMessage: 'Câu trả lời mẫu không thể trống', + supportMarkdown: 'Hỗ trợ Markdown', + }, + openaiNotConfig: { + before: 'Kiểm duyệt OpenAI yêu cầu một khóa API OpenAI được cấu hình trong', + after: '', + }, + }, + }, + }, + automatic: { + title: 'Tự động hóa triển khai ứng dụng', + description: 'Mô tả tình huống của bạn, Dify sẽ tự động hóa một ứng dụng cho bạn.', + intendedAudience: 'Ai là đối tượng mục tiêu?', + intendedAudiencePlaceHolder: 'ví dụ: Sinh viên', + solveProblem: 'Họ hy vọng AI có thể giải quyết vấn đề gì cho họ?', + solveProblemPlaceHolder: 'ví dụ: Đánh giá thành tích học tập', + generate: 'Tạo ra', + audiencesRequired: 'Yêu cầu Đối tượng mục tiêu', + problemRequired: 'Vấn đề cần thiết', + resTitle: 'Chúng tôi đã tự động hóa ứng dụng sau đây cho bạn.', + apply: 'Áp dụng tự động hóa này', + noData: 'Mô tả tình huống sử dụng của bạn ở bên trái, xem trước tự động hóa sẽ hiển thị ở đây.', + loading: 'Tự động hóa ứng dụng cho bạn...', + overwriteTitle: 'Ghi đè cấu hình hiện tại?', + overwriteMessage: 'Áp dụng tự động hóa này sẽ ghi đè lên cấu hình hiện tại.', + }, + resetConfig: { + title: 'Xác nhận đặt lại?', + message: + 'Đặt lại sẽ loại bỏ các thay đổi, khôi phục cấu hình đã xuất bản lần cuối.', + }, + errorMessage: { + nameOfKeyRequired: 'tên của khóa: {{key}} được yêu cầu', + valueOfVarRequired: 'giá trị {{key}} không thể trống', + queryRequired: 'Văn bản yêu cầu được yêu cầu.', + waitForResponse: + 'Vui lòng đợi phản hồi của tin 
nhắn trước để hoàn thành.', + waitForBatchResponse: + 'Vui lòng đợi phản hồi của tác vụ hàng loạt để hoàn thành.', + notSelectModel: 'Vui lòng chọn một mô hình', + waitForImgUpload: 'Vui lòng đợi hình ảnh được tải lên', + }, + chatSubTitle: 'Hướng dẫn', + completionSubTitle: 'Tiền Tố Lời Nhắc', + promptTip: + 'Lời nhắc hướng dẫn các phản hồi của AI với hướng dẫn và ràng buộc. Chèn biến như {{input}}. Lời nhắc này sẽ không được hiển thị cho người dùng.', + formattingChangedTitle: 'Định dạng đã thay đổi', + formattingChangedText: + 'Thay đổi định dạng sẽ đặt lại khu vực gỡ lỗi, bạn có chắc chắn không?', + variableTitle: 'Biến', + variableTip: + 'Người dùng điền các biến vào một biểu mẫu, tự động thay thế các biến trong lời nhắc.', + notSetVar: 'Biến cho phép người dùng giới thiệu các từ khóa lời nhắc hoặc mở đầu khi điền vào biểu mẫu. Bạn có thể thử nhập "{{input}}" trong các từ khóa lời nhắc.', + autoAddVar: 'Biến không xác định được tham chiếu trong tiền-lời nhắc, bạn có muốn thêm chúng vào biểu mẫu đầu vào người dùng không?', + variableTable: { + key: 'Khóa Biến', + name: 'Tên Trường Nhập Liệu Người Dùng', + optional: 'Tùy chọn', + type: 'Loại Nhập Liệu', + action: 'Hành động', + typeString: 'Chuỗi', + typeSelect: 'Chọn', + }, + varKeyError: { + canNoBeEmpty: 'Khóa biến không thể trống', + tooLong: 'Khóa biến: {{key}} quá dài. Không thể dài hơn 30 ký tự', + notValid: 'Khóa biến: {{key}} không hợp lệ. Chỉ có thể chứa chữ cái, số, và dấu gạch dưới', + notStartWithNumber: 'Khóa biến: {{key}} không thể bắt đầu bằng số', + keyAlreadyExists: 'Khóa biến: :{{key}} đã tồn tại', + }, + otherError: { + promptNoBeEmpty: 'Lời nhắc không thể trống', + historyNoBeEmpty: 'Lịch sử cuộc trò chuyện phải được thiết lập trong lời nhắc', + queryNoBeEmpty: 'Truy vấn phải được thiết lập trong lời nhắc', + }, + variableConig: { + modalTitle: 'Thiết lập trường', + description: 'Thiết lập cho biến {{varName}}', + fieldType: 'Loại Trường', + string: 'Văn Bản Ngắn', + paragraph: 'Đoạn Văn', + select: 'Chọn', + notSet: 'Không được thiết lập, hãy thử nhập {{input}} vào lời nhắc tiền.', + stringTitle: 'Tùy chọn hộp văn bản biểu mẫu', + maxLength: 'Độ dài Tối đa', + options: 'Tùy chọn', + addOption: 'Thêm tùy chọn', + apiBasedVar: 'Biến Dựa trên API', + }, + vision: { + name: 'Tầm nhìn', + description: 'Cho phép tầm nhìn sẽ cho phép mô hình nhận hình ảnh và trả lời các câu hỏi về chúng.', + settings: 'Cài đặt', + visionSettings: { + title: 'Cài đặt Tầm nhìn', + resolution: 'Độ phân giải', + resolutionTooltip: `Độ phân giải thấp sẽ cho phép mô hình nhận một phiên bản hình ảnh 512 x 512 thấp hơn, và đại diện cho hình ảnh với ngân sách 65 token. Điều này cho phép API trả về phản hồi nhanh hơn và tiêu thụ ít token đầu vào cho các trường hợp sử dụng không yêu cầu chi tiết cao. + \n + độ phân giải cao sẽ đầu tiên cho phép mô hình nhìn thấy hình ảnh thấp hơn và sau đó tạo ra các cắt chi tiết của hình ảnh đầu vào dưới dạng hình vuông 512px dựa trên kích thước hình ảnh đầu vào. 
Mỗi cắt chi tiết sử dụng hai lần ngân sách token cho tổng cộng 129 token.`,
+      high: 'Cao',
+      low: 'Thấp',
+      uploadMethod: 'Phương thức Tải lên',
+      both: 'Cả hai',
+      localUpload: 'Tải lên Nội bộ',
+      url: 'URL',
+      uploadLimit: 'Giới hạn Tải lên',
+    },
+  },
+  voice: {
+    name: 'Giọng nói',
+    defaultDisplay: 'Giọng mặc định',
+    description: 'Cài đặt giọng nói văn bản thành tiếng',
+    settings: 'Cài đặt',
+    voiceSettings: {
+      title: 'Cài đặt Giọng nói',
+      language: 'Ngôn ngữ',
+      resolutionTooltip: 'Giọng nói văn bản hỗ trợ ngôn ngữ.',
+      voice: 'Giọng',
+    },
+  },
+  openingStatement: {
+    title: 'Mở đầu Trò chuyện',
+    add: 'Thêm',
+    writeOpner: 'Viết câu mở đầu',
+    placeholder: 'Viết thông điệp mở đầu của bạn ở đây, bạn có thể sử dụng biến, hãy thử nhập {{variable}}.',
+    openingQuestion: 'Câu Hỏi Mở đầu',
+    noDataPlaceHolder:
+      'Bắt đầu cuộc trò chuyện với người dùng có thể giúp AI thiết lập một mối quan hệ gần gũi hơn với họ trong các ứng dụng trò chuyện.',
+    varTip: 'Bạn có thể sử dụng biến, hãy thử nhập {{variable}}',
+    tooShort: 'Ít nhất 20 từ của lời nhắc ban đầu được yêu cầu để tạo ra các lời nhắn đầu tiên cho cuộc trò chuyện.',
+    notIncludeKey: 'Lời nhắc ban đầu không bao gồm biến: {{key}}. Vui lòng thêm nó vào lời nhắc ban đầu.',
+  },
+  modelConfig: {
+    model: 'Mô hình',
+    setTone: 'Thiết lập tông của phản hồi',
+    title: 'Mô hình và Tham số',
+    modeType: {
+      chat: 'Trò chuyện',
+      completion: 'Hoàn thành',
+    },
+  },
+  inputs: {
+    title: 'Gỡ Lỗi và Xem Trước',
+    noPrompt: 'Hãy thử viết một số lời nhắc trong trường tiền-lời nhắc',
+    userInputField: 'Trường Nhập Liệu Người Dùng',
+    noVar: 'Điền vào giá trị của biến, nó sẽ tự động thay thế từ khóa lời nhắc mỗi khi bắt đầu phiên mới.',
+    chatVarTip:
+      'Điền vào giá trị của biến, nó sẽ tự động thay thế từ khóa lời nhắc mỗi khi bắt đầu phiên mới',
+    completionVarTip:
+      'Điền vào giá trị của biến, nó sẽ tự động thay thế từ khóa lời nhắc mỗi khi một câu hỏi được gửi.',
+    previewTitle: 'Xem Trước Lời Nhắc',
+    queryTitle: 'Nội dung Truy vấn',
+    queryPlaceholder: 'Vui lòng nhập văn bản yêu cầu.',
+    run: 'CHẠY',
+  },
+  result: 'Văn bản Đầu Ra',
+  datasetConfig: {
+    settingTitle: 'Cài đặt Truy xuất',
+    retrieveOneWay: {
+      title: 'N-to-1 Truy xuất',
+      description: 'Dựa trên ý định của người dùng và mô tả Kiến thức, Agent tự động chọn Kiến thức tốt nhất để truy vấn. Tốt nhất cho các ứng dụng có Kiến thức cụ thể, giới hạn.',
+    },
+    retrieveMultiWay: {
+      title: 'Truy xuất đa hướng',
+      description: 'Dựa trên ý định của người dùng, truy vấn qua tất cả Kiến thức, truy xuất văn bản liên quan từ nhiều nguồn và chọn ra kết quả tốt nhất phù hợp với truy vấn của người dùng sau khi sắp xếp lại. Yêu cầu cấu hình của API Rerank model.',
+    },
+    rerankModelRequired: 'Rerank model là bắt buộc',
+    params: 'Tham số',
+    top_k: 'Top K',
+    top_kTip: 'Sử dụng để lọc các phần chính xác nhất với các câu hỏi của người dùng. 
Hệ thống cũng sẽ tự động điều chỉnh giá trị của Top K, theo max_tokens của mô hình đã chọn.', + score_threshold: 'Ngưỡng Điểm', + score_thresholdTip: 'Sử dụng để thiết lập ngưỡng tương đồng cho việc lọc các phần.', + retrieveChangeTip: 'Thay đổi chế độ chỉ mục và chế độ truy xuất có thể ảnh hưởng đến các ứng dụng liên quan đến Kiến thức này.', + }, + debugAsSingleModel: 'Gỡ Lỗi như Một Mô hình', + debugAsMultipleModel: 'Gỡ Lỗi như Nhiều Mô hình', + duplicateModel: 'Sao chép', + publishAs: 'Xuất bản dưới dạng', + assistantType: { + name: 'Loại Trợ lý', + chatAssistant: { + name: 'Trợ lý Cơ bản', + description: 'Xây dựng một trợ lý dựa trên trò chuyện sử dụng một Mô hình Ngôn ngữ Lớn.', + }, + agentAssistant: { + name: 'Trợ lý Tác nhân', + description: 'Xây dựng một Tác nhân thông minh có thể tự động chọn các công cụ để hoàn thành các nhiệm vụ.', + }, + }, + agent: { + agentMode: 'Chế độ Tác nhân', + agentModeDes: 'Thiết lập loại chế độ suy luận cho tác nhân', + agentModeType: { + ReACT: 'ReAct', + functionCall: 'Gọi Hàm', + }, + setting: { + name: 'Thiết lập Tác nhân', + description: 'Thiết lập Tác nhân cho phép thiết lập chế độ tác nhân và các tính năng nâng cao như các lời nhắc tích hợp sẵn, chỉ có sẵn trong loại Tác nhân.', + maximumIterations: { + name: 'Số Lần Lặp Tối đa', + description: 'Giới hạn số lần lặp một trợ lý tác nhân có thể thực hiện', + }, + }, + buildInPrompt: 'Lời Nhắc Tích Hợp', + firstPrompt: 'Tiền-lời Nhắc Đầu Tiên', + nextIteration: 'Lần Lặp Tiếp Theo', + promptPlaceholder: 'Viết tiền-lời nhắc của bạn ở đây', + tools: { + name: 'Công cụ', + description: 'Sử dụng công cụ có thể mở rộng các khả năng của LLM, như tìm kiếm trên internet hoặc thực hiện các phép tính khoa học', + enabled: 'Đã kích hoạt', + }, + }, +} + +export default translation diff --git a/web/i18n/vi-VN/app-log.ts b/web/i18n/vi-VN/app-log.ts new file mode 100644 index 0000000000..318eb42b86 --- /dev/null +++ b/web/i18n/vi-VN/app-log.ts @@ -0,0 +1,69 @@ +const translation = { + title: 'Nhật ký', + description: 'Nhật ký ghi lại trạng thái hoạt động của ứng dụng, bao gồm đầu vào của người dùng và phản hồi của trí tuệ nhân tạo.', + dateTimeFormat: 'MM/DD/YYYY hh:mm A', + table: { + header: { + time: 'Thời gian', + endUser: 'Người dùng cuối', + input: 'Đầu vào', + output: 'Đầu ra', + summary: 'Tiêu đề', + messageCount: 'Số lượng tin nhắn', + userRate: 'Tỷ lệ người dùng', + adminRate: 'Tỷ lệ quản trị', + }, + pagination: { + previous: 'Trước', + next: 'Tiếp', + }, + empty: { + noChat: 'Chưa có cuộc trò chuyện', + noOutput: 'Không có đầu ra', + element: { + title: 'Có ai ở đó không?', + content: 'Quan sát và ghi chú các tương tác giữa người dùng cuối và ứng dụng trí tuệ nhân tạo ở đây để liên tục cải thiện độ chính xác của trí tuệ nhân tạo. 
Bạn có thể thử chia sẻ hoặc kiểm tra ứng dụng Web của mình, sau đó quay lại trang này.', + }, + }, + }, + detail: { + time: 'Thời gian', + conversationId: 'ID Cuộc trò chuyện', + promptTemplate: 'Mẫu Nhắc nhở', + promptTemplateBeforeChat: 'Mẫu Nhắc nhở Trước Cuộc trò chuyện · Như Tin nhắn Hệ thống', + annotationTip: 'Cải thiện Được Đánh Dấu bởi {{user}}', + timeConsuming: '', + second: 'giây', + tokenCost: 'Token đã tiêu', + loading: 'đang tải', + operation: { + like: 'thích', + dislike: 'không thích', + addAnnotation: 'Thêm Cải thiện', + editAnnotation: 'Chỉnh sửa Cải thiện', + annotationPlaceholder: 'Nhập câu trả lời mong muốn mà bạn muốn trí tuệ nhân tạo trả lời, có thể được sử dụng cho việc điều chỉnh mô hình và cải thiện liên tục chất lượng tạo văn bản trong tương lai.', + }, + variables: 'Biến', + uploadImages: 'Ảnh đã tải lên', + }, + filter: { + period: { + today: 'Hôm nay', + last7days: '7 Ngày qua', + last4weeks: '4 Tuần qua', + last3months: '3 Tháng qua', + last12months: '12 Tháng qua', + monthToDate: 'Từ Đầu tháng đến nay', + quarterToDate: 'Từ Đầu quý đến nay', + yearToDate: 'Từ Đầu năm đến nay', + allTime: 'Tất cả thời gian', + }, + annotation: { + all: 'Tất cả', + annotated: 'Cải thiện Đã Đánh Dấu ({{count}} mục)', + not_annotated: 'Chưa Đánh Dấu', + }, + }, +} + +export default translation diff --git a/web/i18n/vi-VN/app-overview.ts b/web/i18n/vi-VN/app-overview.ts new file mode 100644 index 0000000000..23a256adbe --- /dev/null +++ b/web/i18n/vi-VN/app-overview.ts @@ -0,0 +1,139 @@ +const translation = { + welcome: { + firstStepTip: 'Để bắt đầu,', + enterKeyTip: 'nhập khóa API OpenAI của bạn bên dưới', + getKeyTip: 'Lấy khóa API của bạn từ bảng điều khiển OpenAI', + placeholder: 'Khóa API OpenAI của bạn (ví dụ: sk-xxxx)', + }, + apiKeyInfo: { + cloud: { + trial: { + title: 'Bạn đang sử dụng hạn mức thử nghiệm của {{providerName}}.', + description: 'Hạn mức thử nghiệm được cung cấp cho việc thử nghiệm của bạn. Trước khi hạn mức cuộc gọi thử nghiệm được sử dụng hết, vui lòng thiết lập nhà cung cấp mô hình của riêng bạn hoặc mua thêm hạn mức.', + }, + exhausted: { + title: 'Hạn mức thử nghiệm của bạn đã được sử dụng hết, vui lòng thiết lập APIKey của bạn.', + description: 'Hạn mức thử nghiệm của bạn đã được sử dụng hết. 
Vui lòng thiết lập nhà cung cấp mô hình của riêng bạn hoặc mua thêm hạn mức.', + }, + }, + selfHost: { + title: { + row1: 'Để bắt đầu,', + row2: 'thiết lập nhà cung cấp mô hình của bạn trước.', + }, + }, + callTimes: 'Số lần gọi', + usedToken: 'Token đã sử dụng', + setAPIBtn: 'Đi đến thiết lập nhà cung cấp mô hình', + tryCloud: 'Hoặc thử phiên bản đám mây của Dify với báo giá miễn phí', + }, + overview: { + title: 'Tổng quan', + appInfo: { + explanation: 'WebApp Trí tuệ nhân tạo Sẵn sàng sử dụng', + accessibleAddress: 'URL Công cộng', + preview: 'Xem trước', + regenerate: 'Tạo lại', + preUseReminder: 'Vui lòng kích hoạt WebApp trước khi tiếp tục.', + settings: { + entry: 'Cài đặt', + title: 'Cài đặt WebApp', + webName: 'Tên WebApp', + webDesc: 'Mô tả WebApp', + webDescTip: 'Văn bản này sẽ được hiển thị trên phía máy khách, cung cấp hướng dẫn cơ bản về cách sử dụng ứng dụng', + webDescPlaceholder: 'Nhập mô tả của WebApp', + language: 'Ngôn ngữ', + more: { + entry: 'Hiển thị thêm cài đặt', + copyright: 'Bản quyền', + copyRightPlaceholder: 'Nhập tên tác giả hoặc tổ chức', + privacyPolicy: 'Chính sách Bảo mật', + privacyPolicyPlaceholder: 'Nhập liên kết chính sách bảo mật', + privacyPolicyTip: 'Giúp người truy cập hiểu về dữ liệu mà ứng dụng thu thập, xem Chính sách Bảo mật của Dify.', + }, + }, + embedded: { + entry: 'Nhúng', + title: 'Nhúng vào trang web', + explanation: 'Chọn cách nhúng ứng dụng trò chuyện vào trang web của bạn', + iframe: 'Để thêm ứng dụng trò chuyện ở bất kỳ đâu trên trang web của bạn, thêm iframe này vào mã html của bạn.', + scripts: 'Để thêm ứng dụng trò chuyện vào phía dưới bên phải của trang web của bạn, thêm mã này vào mã html của bạn.', + chromePlugin: 'Cài đặt Phần mở rộng Chrome Dify Chatbot', + copied: 'Đã sao chép', + copy: 'Sao chép', + }, + qrcode: { + title: 'Mã QR để chia sẻ', + scan: 'Quét để chia sẻ ứng dụng', + download: 'Tải về Mã QR', + }, + customize: { + way: 'cách', + entry: 'Tùy chỉnh', + title: 'Tùy chỉnh WebApp Trí tuệ nhân tạo', + explanation: 'Bạn có thể tùy chỉnh giao diện trước của ứng dụng Web để phù hợp với kịch bản và nhu cầu phong cách của bạn.', + way1: { + name: 'Fork mã nguồn máy khách, chỉnh sửa và triển khai lên Vercel (được khuyến nghị)', + step1: 'Fork mã nguồn máy khách và chỉnh sửa', + step1Tip: 'Nhấn vào đây để fork mã nguồn vào tài khoản GitHub của bạn và chỉnh sửa mã', + step1Operation: 'Dify-WebClient', + step2: 'Triển khai lên Vercel', + step2Tip: 'Nhấn vào đây để nhập kho vào Vercel và triển khai', + step2Operation: 'Nhập kho', + step3: 'Cấu hình biến môi trường', + step3Tip: 'Thêm các biến môi trường sau vào Vercel', + }, + way2: { + name: 'Viết mã phía máy khách để gọi API và triển khai nó lên máy chủ', + operation: 'Tài liệu', + }, + }, + }, + apiInfo: { + title: 'API Dịch vụ Backend', + explanation: 'Dễ dàng tích hợp vào ứng dụng của bạn', + accessibleAddress: 'Điểm cuối API Dịch vụ', + doc: 'Tài liệu Tham khảo API', + }, + status: { + running: 'Đang hoạt động', + disable: 'Tắt', + }, + }, + analysis: { + title: 'Phân tích', + ms: 'ms', + tokenPS: 'Token/s', + totalMessages: { + title: 'Tổng số tin nhắn', + explanation: 'Số lượt tương tác AI hàng ngày; không bao gồm kỹ thuật kịch bản / gỡ lỗi.', + }, + activeUsers: { + title: 'Người dùng hoạt động', + explanation: 'Người dùng duy nhất tham gia trò chuyện với AI; không bao gồm kỹ thuật kịch bản / gỡ lỗi.', + }, + tokenUsage: { + title: 'Sử dụng Token', + explanation: 'Phản ánh việc sử dụng token hàng ngày của mô hình ngôn ngữ cho ứng dụng, hữu ích cho mục đích kiểm soát 
chi phí.',
+      consumed: 'Đã tiêu',
+    },
+    avgSessionInteractions: {
+      title: 'Trung bình Tương tác trong phiên',
+      explanation: 'Số lượt giao tiếp giữa người dùng và AI liên tục; cho các ứng dụng dựa trên cuộc trò chuyện.',
+    },
+    userSatisfactionRate: {
+      title: 'Tỷ lệ Hài lòng của Người dùng',
+      explanation: 'Số lượng thích cho mỗi 1.000 tin nhắn. Điều này cho thấy tỷ lệ phản hồi mà người dùng rất hài lòng.',
+    },
+    avgResponseTime: {
+      title: 'Trung bình Thời gian Phản hồi',
+      explanation: 'Thời gian (ms) để AI xử lý / phản hồi; cho các ứng dụng dựa trên văn bản.',
+    },
+    tps: {
+      title: 'Tốc độ Đầu ra Token',
+      explanation: 'Đo hiệu suất của LLM. Đếm tốc độ đầu ra Token của LLM từ khi bắt đầu yêu cầu cho đến khi hoàn thành đầu ra.',
+    },
+  },
+}
+
+export default translation
diff --git a/web/i18n/vi-VN/app.ts b/web/i18n/vi-VN/app.ts
new file mode 100644
index 0000000000..b321f53d3e
--- /dev/null
+++ b/web/i18n/vi-VN/app.ts
@@ -0,0 +1,54 @@
+const translation = {
+  createApp: 'Tạo ứng dụng mới',
+  types: {
+    all: 'Tất cả',
+    assistant: 'Trợ lý',
+    completion: 'Hoàn thành',
+  },
+  modes: {
+    completion: 'Máy tạo văn bản',
+    chat: 'Trợ lý Cơ bản',
+  },
+  createFromConfigFile: 'Tạo ứng dụng từ tệp cấu hình',
+  deleteAppConfirmTitle: 'Xóa ứng dụng này?',
+  deleteAppConfirmContent:
+    'Việc xóa ứng dụng là không thể đảo ngược. Người dùng sẽ không thể truy cập vào ứng dụng của bạn nữa, và tất cả cấu hình và nhật ký mẫu sẽ bị xóa vĩnh viễn.',
+  appDeleted: 'Ứng dụng đã bị xóa',
+  appDeleteFailed: 'Xóa ứng dụng thất bại',
+  join: 'Tham gia cộng đồng',
+  communityIntro:
+    'Thảo luận với các thành viên nhóm, người đóng góp và nhà phát triển trên các kênh khác nhau.',
+  roadmap: 'Xem lộ trình của chúng tôi',
+  appNamePlaceholder: 'Vui lòng nhập tên của ứng dụng',
+  newApp: {
+    startToCreate: 'Hãy bắt đầu với ứng dụng mới của bạn',
+    captionName: 'Biểu tượng & tên ứng dụng',
+    captionAppType: 'Bạn muốn tạo loại ứng dụng nào?',
+    previewDemo: 'Xem trước demo',
+    chatApp: 'Trợ lý',
+    chatAppIntro:
+      'Tôi muốn xây dựng một ứng dụng dựa trên trò chuyện. Ứng dụng này sử dụng định dạng câu hỏi và trả lời, cho phép nhiều vòng trò chuyện liên tục.',
+    agentAssistant: 'Trợ lý Mới cho Đại lý',
+    completeApp: 'Máy Tạo Văn Bản',
+    completeAppIntro:
+      'Tôi muốn tạo một ứng dụng tạo ra văn bản chất lượng cao dựa trên các mẫu, như việc tạo bài viết, tóm tắt, dịch thuật và nhiều hơn nữa.',
+    showTemplates: 'Tôi muốn chọn từ một mẫu',
+    hideTemplates: 'Quay lại lựa chọn chế độ',
+    Create: 'Tạo',
+    Cancel: 'Hủy',
+    nameNotEmpty: 'Tên không được để trống',
+    appTemplateNotSelected: 'Vui lòng chọn một mẫu',
+    appTypeRequired: 'Vui lòng chọn loại ứng dụng',
+    appCreated: 'Ứng dụng đã được tạo',
+    appCreateFailed: 'Không thể tạo ứng dụng',
+  },
+  editApp: {
+    startToEdit: 'Chỉnh sửa ứng dụng',
+  },
+  emoji: {
+    ok: 'Đồng ý',
+    cancel: 'Hủy',
+  },
+}
+
+export default translation
diff --git a/web/i18n/vi-VN/billing.ts b/web/i18n/vi-VN/billing.ts
new file mode 100644
index 0000000000..71abd8a884
--- /dev/null
+++ b/web/i18n/vi-VN/billing.ts
@@ -0,0 +1,115 @@
+const translation = {
+  currentPlan: 'Kế hoạch Hiện tại',
+  upgradeBtn: {
+    plain: 'Nâng cấp Kế hoạch',
+    encourage: 'Nâng cấp Ngay',
+    encourageShort: 'Nâng cấp',
+  },
+  viewBilling: 'Quản lý thanh toán và đăng ký',
+  buyPermissionDeniedTip: 'Vui lòng liên hệ với quản trị viên doanh nghiệp của bạn để đăng ký',
+  plansCommon: {
+    title: 'Chọn một kế hoạch phù hợp với bạn',
+    yearlyTip: 'Nhận 2 tháng miễn phí khi đăng ký hàng năm!',
+    mostPopular: 'Phổ biến nhất',
+    planRange: {
+      monthly: 'Hàng tháng',
+      yearly: 'Hàng năm',
+    },
+    month: 'tháng',
+    year: 'năm',
+    save: 'Tiết kiệm ',
+    free: 'Miễn phí',
+    currentPlan: 'Kế hoạch Hiện tại',
+    contractSales: 'Liên hệ bộ phận bán hàng',
+    contractOwner: 'Liên hệ quản lý nhóm',
+    startForFree: 'Bắt đầu miễn phí',
+    getStartedWith: 'Bắt đầu với ',
+    contactSales: 'Liên hệ Bán hàng',
+    talkToSales: 'Nói chuyện với Bộ phận Bán hàng',
+    modelProviders: 'Nhà cung cấp Mô hình',
+    teamMembers: 'Thành viên Nhóm',
+    buildApps: 'Xây dựng Ứng dụng',
+    vectorSpace: 'Không gian Vector',
+    vectorSpaceBillingTooltip: 'Mỗi 1MB có thể lưu trữ khoảng 1.2 triệu ký tự dữ liệu vector hóa (ước tính sử dụng OpenAI Embeddings, thay đổi tùy theo các mô hình).',
+    vectorSpaceTooltip: 'Không gian Vector là hệ thống bộ nhớ dài hạn cần thiết cho LLMs để hiểu dữ liệu của bạn.',
+    documentsUploadQuota: 'Hạn mức Tải lên Tài liệu',
+    documentProcessingPriority: 'Ưu tiên Xử lý Tài liệu',
+    documentProcessingPriorityTip: 'Để có ưu tiên xử lý tài liệu cao hơn, vui lòng nâng cấp kế hoạch của bạn.',
+    documentProcessingPriorityUpgrade: 'Xử lý nhiều dữ liệu với độ chính xác cao và tốc độ nhanh hơn.',
+    priority: {
+      'standard': 'Tiêu chuẩn',
+      'priority': 'Ưu tiên',
+      'top-priority': 'Ưu tiên Cao nhất',
+    },
+    logsHistory: 'Lịch sử Nhật ký',
+    customTools: 'Công cụ Tùy chỉnh',
+    unavailable: 'Không có sẵn',
+    days: 'ngày',
+    unlimited: 'Không giới hạn',
+    support: 'Hỗ trợ',
+    supportItems: {
+      communityForums: 'Diễn đàn cộng đồng',
+      emailSupport: 'Hỗ trợ qua email',
+      priorityEmail: 'Hỗ trợ qua email & chat ưu tiên',
+      logoChange: 'Thay đổi Logo',
+      SSOAuthentication: 'Xác thực SSO',
+      personalizedSupport: 'Hỗ trợ cá nhân hóa',
+      dedicatedAPISupport: 'Hỗ trợ API dành riêng',
+      customIntegration: 'Tích hợp và hỗ trợ tùy chỉnh',
+      ragAPIRequest: 'Yêu cầu API RAG',
+      bulkUpload: 'Tải lên tài liệu hàng loạt',
+      agentMode: 'Chế độ Đại lý',
+      workflow: 'Quy trình làm việc',
+    },
+    comingSoon: 'Sắp ra mắt',
+    member: 'Thành viên',
+    memberAfter: 'Thành viên',
+    messageRequest: {
+      title: 'Số Lượng Tin Nhắn',
+      tooltip: 'Hạn mức triệu hồi tin nhắn cho các kế hoạch sử dụng mô hình OpenAI (ngoại trừ gpt4). Các tin nhắn vượt quá giới hạn sẽ sử dụng Khóa API OpenAI của bạn.',
+    },
+    annotatedResponse: {
+      title: 'Hạn Mức Quota Phản hồi Đã được Ghi chú',
+      tooltip: 'Chỉnh sửa và ghi chú thủ công các phản hồi cung cấp khả năng trả lời câu hỏi chất lượng cao có thể tùy chỉnh cho các ứng dụng. (Chỉ áp dụng trong các ứng dụng trò chuyện)',
+    },
+    ragAPIRequestTooltip: 'Đề cập đến số lượng cuộc gọi API triệu hồi chỉ khả năng xử lý cơ sở kiến thức của Dify.',
+    receiptInfo: 'Chỉ chủ nhóm và quản trị viên nhóm có thể đăng ký và xem thông tin thanh toán',
+  },
+  plans: {
+    sandbox: {
+      name: 'Hộp Cát',
+      description: 'Thử nghiệm miễn phí 200 lần GPT',
+      includesTitle: 'Bao gồm:',
+    },
+    professional: {
+      name: 'Chuyên nghiệp',
+      description: 'Dành cho cá nhân và các nhóm nhỏ để mở khóa nhiều sức mạnh với giá cả phải chăng.',
+      includesTitle: 'Tất cả trong kế hoạch miễn phí, cộng thêm:',
+    },
+    team: {
+      name: 'Nhóm',
+      description: 'Hợp tác mà không giới hạn và tận hưởng hiệu suất hạng nhất.',
+      includesTitle: 'Tất cả trong kế hoạch Chuyên nghiệp, cộng thêm:',
+    },
+    enterprise: {
+      name: 'Doanh nghiệp',
+      description: 'Nhận toàn bộ khả năng và hỗ trợ cho các hệ thống quan trọng cho nhiệm vụ quy mô lớn.',
+      includesTitle: 'Tất cả trong kế hoạch Nhóm, cộng thêm:',
+    },
+  },
+  vectorSpace: {
+    fullTip: 'Không gian Vector đã đầy.',
+    fullSolution: 'Nâng cấp kế hoạch của bạn để có thêm không gian.',
+  },
+  apps: {
+    fullTipLine1: 'Nâng cấp kế hoạch của bạn để',
+    fullTipLine2: 'xây dựng thêm ứng dụng.',
+  },
+  annotatedResponse: {
+    fullTipLine1: 'Nâng cấp kế hoạch của bạn để',
+    fullTipLine2: 'ghi chú thêm cuộc trò chuyện.',
+    quotaTitle: 'Hạn Mức Quota Phản hồi Đã được Ghi chú',
+  },
+}
+
+export default translation
diff --git a/web/i18n/vi-VN/common.ts b/web/i18n/vi-VN/common.ts
new file mode 100644
index 0000000000..ed130d8fa0
--- /dev/null
+++ b/web/i18n/vi-VN/common.ts
@@ -0,0 +1,505 @@
+const translation = {
+  api: {
+    success: 'Thành công',
+    actionSuccess: 'Thành công',
+    saved: 'Đã lưu',
+    create: 'Tạo',
+    remove: 'Xóa',
+  },
+  operation: {
+    create: 'Tạo mới',
+    confirm: 'Xác nhận',
+    cancel: 'Hủy bỏ',
+    clear: 'Xóa',
+    save: 'Lưu',
+    edit: 'Chỉnh sửa',
+    add: 'Thêm',
+    added: 'Đã thêm',
+    refresh: 'Làm mới',
+    reset: 'Đặt lại',
+    search: 'Tìm kiếm',
+    change: 'Thay đổi',
+    remove: 'Xóa',
+    send: 'Gửi',
+    copy: 'Sao chép',
+    lineBreak: 'Ngắt dòng',
+    sure: 'Tôi chắc chắn',
+    download: 'Tải xuống',
+    delete: 'Xóa',
+    settings: 'Cài đặt',
+    setup: 'Thiết lập',
+    getForFree: 'Nhận miễn phí',
+    reload: 'Tải lại',
+    ok: 'OK',
+    log: 'Nhật ký',
+    learnMore: 'Tìm hiểu thêm',
+    params: 'Tham số',
+  },
+  placeholder: {
+    input: 'Vui lòng nhập',
+    select: 'Vui lòng chọn',
+  },
+  voice: {
+    language: {
+      zhHans: 'Tiếng Trung',
+      enUS: 'Tiếng Anh',
+      deDE: 'Tiếng Đức',
+      frFR: 'Tiếng Pháp',
+      esES: 'Tiếng Tây Ban Nha',
+      itIT: 'Tiếng Ý',
+      thTH: 'Tiếng Thái',
+      idID: 'Tiếng Indonesia',
+      jaJP: 'Tiếng Nhật',
+      koKR: 'Tiếng Hàn',
+      ptBR: 'Tiếng Bồ Đào Nha',
+      ruRU: 'Tiếng Nga',
+      ukUA: 'Tiếng Ukraina',
+    },
+  },
+  unit: {
+    char: 'ký tự',
+  },
+  actionMsg: {
+    noModification: 'Hiện không có sự thay đổi.',
+    modifiedSuccessfully: 'Chỉnh sửa thành công',
+    modifiedUnsuccessfully: 'Chỉnh sửa không thành công',
+    copySuccessfully: 'Đã sao chép thành công',
+    paySucceeded: 'Thanh toán thành công',
+    payCancelled: 'Thanh toán đã hủy',
+    generatedSuccessfully: 'Tạo thành công',
+    generatedUnsuccessfully: 'Tạo không thành công',
+  },
+  model: {
+    params: {
+      temperature: 'Nhiệt độ',
+      temperatureTip:
+        'Kiểm soát sự ngẫu nhiên: Giảm nhiệt độ dẫn đến ít kết quả hoàn thành ngẫu nhiên hơn. Khi nhiệt độ tiến gần về không, mô hình sẽ trở nên xác định và lặp lại.',
+      top_p: 'Top P',
+      top_pTip:
+        'Kiểm soát đa dạng thông qua lấy mẫu nhân nhóm: 0.5 có nghĩa là nửa số tùy chọn có khả năng cao được xem xét.',
+      presence_penalty: 'Phạt sự hiện diện',
+      presence_penaltyTip:
+        'Độ lớn của sự phạt cho các token mới dựa trên việc chúng có xuất hiện trong văn bản cho đến nay hay không.\nTăng khả năng của mô hình để nói về các chủ đề mới.',
+      frequency_penalty: 'Phạt tần suất',
+      frequency_penaltyTip:
+        'Độ lớn của sự phạt cho các token mới dựa trên tần suất hiện tại của chúng trong văn bản cho đến nay.\nGiảm khả năng của mô hình để lặp lại cùng một dòng văn bản.',
+      max_tokens: 'Max token',
+      max_tokensTip:
+        'Sử dụng để giới hạn độ dài tối đa của câu trả lời, theo token. \nGiá trị lớn có thể giới hạn không gian còn lại cho từ khóa khởi đầu, nhật ký trò chuyện và Kiến thức. \nKhuyến nghị đặt giá trị dưới hai phần ba của gpt-4-1106-preview, gpt-4-vision-preview max token (đầu vào 128k đầu ra 4k)',
+      maxTokenSettingTip: 'Cài đặt max token của bạn quá cao, có thể hạn chế không gian cho từ khóa, truy vấn và dữ liệu. Xem xét đặt nó dưới 2/3.',
+      setToCurrentModelMaxTokenTip: 'Max token được cập nhật đến 80% token tối đa của mô hình hiện tại {{maxToken}}.',
+      stop_sequences: 'Chuỗi dừng',
+      stop_sequencesTip: 'Lên đến bốn chuỗi nơi API sẽ dừng việc tạo ra các token tiếp theo. Văn bản được trả về sẽ không chứa chuỗi dừng.',
+      stop_sequencesPlaceholder: 'Nhập chuỗi và nhấn Tab',
+    },
+    tone: {
+      Creative: 'Sáng tạo',
+      Balanced: 'Cân bằng',
+      Precise: 'Chính xác',
+      Custom: 'Tùy chỉnh',
+    },
+    addMoreModel: 'Điều chỉnh cài đặt để thêm mô hình',
+  },
+  menus: {
+    status: 'beta',
+    explore: 'Khám phá',
+    apps: 'Studio',
+    plugins: 'Plugins',
+    pluginsTips: 'Tích hợp các plugin bên thứ ba hoặc tạo ra các AI-Plugin tương thích với ChatGPT.',
+    datasets: 'Kiến thức',
+    datasetsTips: 'SẮP RA MẮT: Nhập dữ liệu văn bản của bạn hoặc viết dữ liệu theo thời gian thực thông qua Webhook để cải thiện ngữ cảnh LLM.',
+    newApp: 'Ứng dụng mới',
+    newDataset: 'Tạo Kiến thức',
+    tools: 'Công cụ',
+  },
+  userProfile: {
+    settings: 'Cài đặt',
+    workspace: 'Không gian làm việc',
+    createWorkspace: 'Tạo Không gian làm việc',
+    helpCenter: 'Trung tâm trợ giúp',
+    roadmapAndFeedback: 'Lộ trình & Phản hồi',
+    community: 'Cộng đồng',
+    about: 'Về chúng tôi',
+    logout: 'Đăng xuất',
+  },
+  settings: {
+    accountGroup: 'ACCOUNT',
+    workplaceGroup: 'WORKSPACE',
+    account: 'Tài khoản của tôi',
+    members: 'Thành viên',
+    billing: 'Thanh toán',
+    integrations: 'Tích hợp',
+    language: 'Ngôn ngữ',
+    provider: 'Nhà cung cấp mô hình',
+    dataSource: 'Nguồn dữ liệu',
+    plugin: 'Plugins',
+    apiBasedExtension: 'Mở rộng dựa trên API',
+  },
+  account: {
+    avatar: 'Ảnh đại diện',
+    name: 'Tên',
+    email: 'Email',
+    password: 'Mật khẩu',
+    passwordTip: 'Bạn có thể đặt một mật khẩu cố định nếu bạn không muốn sử dụng mã đăng nhập tạm thời',
+    setPassword: 'Đặt mật khẩu',
+    resetPassword: 'Đặt lại mật khẩu',
+    currentPassword: 'Mật khẩu hiện tại',
+    newPassword: 'Mật khẩu mới',
+    confirmPassword: 'Xác nhận mật khẩu',
+    notEqual: 'Hai mật khẩu không giống nhau.',
+    langGeniusAccount: 'Tài khoản Dify',
+    langGeniusAccountTip: 'Tài khoản Dify của bạn và dữ liệu người dùng liên quan.',
+    editName: 'Chỉnh sửa Tên',
+    showAppLength: 'Hiển thị {{length}} ứng dụng',
+  },
+  members: {
+    team: 'Nhóm',
+    invite: 'Mời',
+    name: 'TÊN',
+    lastActive: 'HOẠT ĐỘNG GẦN ĐÂY',
+    role: 'VAI TRÒ',
+    pending: 'Đang chờ...',
+    owner: 'Chủ sở hữu',
+    admin: 'Quản trị viên',
+    adminTip: 'Có thể xây dựng ứng dụng và quản lý cài đặt nhóm',
+    normal: 'Bình thường',
+    normalTip: 'Chỉ có thể sử dụng ứng dụng, không thể xây dựng ứng dụng',
+    inviteTeamMember: 'Mời thành viên nhóm',
+    inviteTeamMemberTip: 'Sau khi đăng nhập, họ có thể truy cập trực tiếp vào dữ liệu nhóm của bạn.',
+    email: 'Email',
+    emailInvalid: 'Định dạng Email không hợp lệ',
+    emailPlaceholder: 'Vui lòng nhập email',
+    sendInvite: 'Gửi Lời mời',
+    invitedAsRole: 'Được mời với vai trò {{role}}',
+    invitationSent: 'Lời mời đã được gửi',
+    invitationSentTip: 'Lời mời đã được gửi, và họ có thể đăng nhập vào Dify để truy cập vào dữ liệu nhóm của bạn.',
+    invitationLink: 'Liên kết Lời mời',
+    failedinvitationEmails: 'Dưới đây là danh sách email không gửi được lời mời',
+    ok: 'OK',
+    removeFromTeam: 'Xóa khỏi nhóm',
+    removeFromTeamTip: 'Sẽ xóa quyền truy cập nhóm',
+    setAdmin: 'Đặt làm quản trị viên',
+    setMember: 'Đặt thành viên bình thường',
+    disinvite: 'Hủy lời mời',
+    deleteMember: 'Xóa thành viên',
+    you: '(Bạn)',
+  },
+  integrations: {
+    connected: 'Đã kết nối',
+    google: 'Google',
+    googleAccount: 'Đăng nhập bằng tài khoản Google',
+    github: 'GitHub',
+    githubAccount: 'Đăng nhập bằng tài khoản GitHub',
+    connect: 'Kết nối',
+  },
+  language: {
+    displayLanguage: 'Ngôn ngữ hiển thị',
+    timezone: 'Múi giờ',
+  },
+  provider: {
+    apiKey: 'Khóa API',
+    enterYourKey: 'Nhập khóa API của bạn ở đây',
+    invalidKey: 'Khóa API OpenAI không hợp lệ',
+    validatedError: 'Xác minh thất bại: ',
+    validating: 'Đang xác minh khóa...',
+    saveFailed: 'Lưu khóa API thất bại',
+    apiKeyExceedBill: 'Khóa API này không có lượng truy vấn khả dụng, vui lòng đọc',
+    addKey: 'Thêm Khóa',
+    comingSoon: 'Sắp Ra Mắt',
+    editKey: 'Chỉnh sửa',
+    invalidApiKey: 'Khóa API không hợp lệ',
+    azure: {
+      apiBase: 'Cơ sở API',
+      apiBasePlaceholder: 'URL cơ sở API của điểm cuối Azure OpenAI của bạn.',
+      apiKey: 'Khóa API',
+      apiKeyPlaceholder: 'Nhập khóa API của bạn ở đây',
+      helpTip: 'Tìm hiểu Dịch vụ Azure OpenAI',
+    },
+    openaiHosted: {
+      openaiHosted: 'OpenAI đang lưu trữ',
+      onTrial: 'ĐANG DÙNG THỬ',
+      exhausted: 'HẾT QUOTA',
+      desc: 'Dịch vụ lưu trữ OpenAI được cung cấp bởi Dify cho phép bạn sử dụng các mô hình như GPT-3.5. Trước khi hết lượng truy vấn dùng thử, bạn cần thiết lập các nhà cung cấp mô hình khác.',
+      callTimes: 'Số lần gọi',
+      usedUp: 'Quota dùng thử đã hết. Thêm nhà cung cấp Mô hình của riêng bạn.',
+      useYourModel: 'Hiện đang sử dụng nhà cung cấp Mô hình của riêng bạn.',
+      close: 'Đóng',
+    },
+    anthropicHosted: {
+      anthropicHosted: 'Anthropic Claude',
+      onTrial: 'ĐANG DÙNG THỬ',
+      exhausted: 'HẾT QUOTA',
+      desc: 'Mô hình mạnh mẽ, vượt trội trong một loạt các nhiệm vụ từ trò chuyện phức tạp và tạo nội dung sáng tạo đến hướng dẫn chi tiết.',
+      callTimes: 'Số lần gọi',
+      usedUp: 'Quota dùng thử đã hết. Thêm nhà cung cấp Mô hình của riêng bạn.',
+      useYourModel: 'Hiện đang sử dụng nhà cung cấp Mô hình của riêng bạn.',
+      close: 'Đóng',
+    },
+    anthropic: {
+      using: 'Khả năng nhúng đang sử dụng',
+      enableTip: 'Để kích hoạt mô hình Anthropic, bạn cần ràng buộc với Dịch vụ OpenAI hoặc Azure OpenAI trước.',
+      notEnabled: 'Chưa được kích hoạt',
+      keyFrom: 'Nhận khóa API của bạn từ Anthropic',
+    },
+    encrypted: {
+      front: 'Khóa API của bạn sẽ được mã hóa và lưu trữ bằng',
+      back: ' công nghệ.',
+    },
+  },
+  modelProvider: {
+    notConfigured: 'Mô hình hệ thống vẫn chưa được cấu hình hoàn toàn và một số chức năng có thể không khả dụng.',
+    systemModelSettings: 'Cài đặt Mô hình Hệ thống',
+    systemModelSettingsLink: 'Tại sao cần thiết phải thiết lập mô hình hệ thống?',
+    selectModel: 'Chọn mô hình của bạn',
+    setupModelFirst: 'Vui lòng thiết lập mô hình của bạn trước',
+    systemReasoningModel: {
+      key: 'Mô hình lập luận hệ thống',
+      tip: 'Thiết lập mô hình suy luận mặc định sẽ được sử dụng để tạo ứng dụng, cũng như các tính năng như việc tạo tên cuộc trò chuyện và đề xuất câu hỏi tiếp theo cũng sẽ sử dụng mô hình suy luận mặc định.',
+    },
+    embeddingModel: {
+      key: 'Mô hình nhúng',
+      tip: 'Thiết lập mô hình mặc định cho việc xử lý nhúng tài liệu của Kiến thức, cả hai phương tiện truy xuất và nhập của Kiến thức đều sử dụng mô hình nhúng này cho xử lý vector hóa. Chuyển đổi sẽ làm cho kích thước vector giữa Kiến thức được nhập và câu hỏi không nhất quán, dẫn đến việc truy xuất thất bại. Để tránh truy xuất thất bại, vui lòng không chuyển đổi mô hình này tùy ý.',
+      required: 'Mô hình nhúng là bắt buộc',
+    },
+    speechToTextModel: {
+      key: 'Mô hình Chuyển đổi Tiếng nói thành Văn bản',
+      tip: 'Thiết lập mô hình mặc định cho đầu vào chuyển đổi tiếng nói thành văn bản trong cuộc trò chuyện.',
+    },
+    ttsModel: {
+      key: 'Mô hình Văn bản thành Tiếng nói',
+      tip: 'Thiết lập mô hình mặc định cho đầu vào văn bản thành tiếng nói trong cuộc trò chuyện.',
+    },
+    rerankModel: {
+      key: 'Mô hình Sắp xếp lại',
+      tip: 'Mô hình sắp xếp lại sẽ sắp xếp lại danh sách tài liệu ứng cử viên dựa trên sự phù hợp ngữ nghĩa với truy vấn của người dùng, cải thiện kết quả của việc xếp hạng ngữ nghĩa',
+    },
+    quota: 'Hạn mức',
+    searchModel: 'Mô hình tìm kiếm',
+    noModelFound: 'Không tìm thấy mô hình cho {{model}}',
+    models: 'Mô hình',
+    showMoreModelProvider: 'Hiển thị thêm nhà cung cấp mô hình',
+    selector: {
+      tip: 'Mô hình này đã bị xóa. Vui lòng thêm một mô hình hoặc chọn mô hình khác.',
+      emptyTip: 'Không có mô hình khả dụng',
+      emptySetting: 'Vui lòng vào cài đặt để cấu hình',
+      rerankTip: 'Vui lòng thiết lập mô hình sắp xếp lại',
+    },
+    card: {
+      quota: 'QUOTA',
+      onTrial: 'Thử nghiệm',
+      paid: 'Đã thanh toán',
+      quotaExhausted: 'Quota đã hết',
+      callTimes: 'Số lần gọi',
+      tokens: 'Tokens',
+      buyQuota: 'Mua Quota',
+      priorityUse: 'Ưu tiên sử dụng',
+      removeKey: 'Xóa Khóa API',
+      tip: 'Ưu tiên sẽ được trao cho hạn ngạch đã thanh toán. Hạn ngạch dùng thử sẽ được sử dụng sau khi hết hạn ngạch trả phí.',
+    },
+    item: {
+      deleteDesc: 'Các mô hình {{modelName}} đang được sử dụng như là các mô hình lập luận hệ thống. Một số chức năng sẽ không khả dụng sau khi loại bỏ. Vui lòng xác nhận.',
+      freeQuota: 'QUYỀN LỢI MIỄN PHÍ',
+    },
+    addApiKey: 'Thêm khóa API của bạn',
+    invalidApiKey: 'Khóa API không hợp lệ',
+    encrypted: {
+      front: 'Khóa API của bạn sẽ được mã hóa và lưu trữ bằng',
+      back: ' công nghệ.',
+    },
+    freeQuota: {
+      howToEarn: 'Cách kiếm',
+    },
+    addMoreModelProvider: 'THÊM NHÀ CUNG CẤP MÔ HÌNH',
+    addModel: 'Thêm Mô hình',
+    modelsNum: '{{num}} Mô hình',
+    showModels: 'Hiện Mô hình',
+    showModelsNum: 'Hiện {{num}} Mô hình',
+    collapse: 'Thu gọn',
+    config: 'Cấu hình',
+    modelAndParameters: 'Mô hình và Tham số',
+    model: 'Mô hình',
+    featureSupported: '{{feature}} được hỗ trợ',
+    callTimes: 'Số lần gọi',
+    credits: 'Tín dụng Tin nhắn',
+    buyQuota: 'Mua Quyền lợi',
+    getFreeTokens: 'Nhận mã thông báo miễn phí',
+    priorityUsing: 'Ưu tiên sử dụng',
+    deprecated: 'Đã lỗi thời',
+    confirmDelete: 'Xác nhận xóa?',
+    quotaTip: 'Số lượng mã thông báo miễn phí còn lại',
+    loadPresets: 'Tải Cài đặt trước',
+    parameters: 'THAM SỐ',
+  },
+  dataSource: {
+    add: 'Thêm nguồn dữ liệu',
+    connect: 'Kết nối',
+    notion: {
+      title: 'Notion',
+      description: 'Sử dụng Notion như một nguồn dữ liệu cho Kiến thức.',
+      connectedWorkspace: 'Không gian làm việc đã kết nối',
+      addWorkspace: 'Thêm không gian làm việc',
+      connected: 'Đã kết nối',
+      disconnected: 'Đã ngắt kết nối',
+      changeAuthorizedPages: 'Thay đổi trang được ủy quyền',
+      pagesAuthorized: 'Các trang được ủy quyền',
+      sync: 'Đồng bộ',
+      remove: 'Xóa',
+      selector: {
+        pageSelected: 'Các trang đã chọn',
+        searchPages: 'Tìm kiếm trang...',
+        noSearchResult: 'Không có kết quả tìm kiếm',
+        addPages: 'Thêm trang',
+        preview: 'Xem trước',
+      },
+    },
+  },
+  plugin: {
+    serpapi: {
+      apiKey: 'Khóa API',
+      apiKeyPlaceholder: 'Nhập khóa API của bạn',
+      keyFrom: 'Nhận khóa SerpAPI của bạn từ Trang tài khoản SerpAPI',
+    },
+  },
+  apiBasedExtension: {
+    title: 'Các tiện ích API cung cấp quản lý API tập trung, giúp cấu hình dễ dàng sử dụng trên các ứng dụng của Dify.',
+    link: 'Tìm hiểu cách phát triển Phần mở rộng API của riêng bạn.',
+    linkUrl: 'https://docs.dify.ai/features/extension/api_based_extension',
+    add: 'Thêm Phần mở rộng API',
+    selector: {
+      title: 'Phần mở rộng API',
+      placeholder: 'Vui lòng chọn phần mở rộng API',
+      manage: 'Quản lý Phần mở rộng API',
+    },
+    modal: {
+      title: 'Thêm Phần mở rộng API',
+      editTitle: 'Chỉnh sửa Phần mở rộng API',
+      name: {
+        title: 'Tên',
+        placeholder: 'Vui lòng nhập tên',
+      },
+      apiEndpoint: {
+        title: 'Điểm cuối API',
+        placeholder: 'Vui lòng nhập điểm cuối API',
+      },
+      apiKey: {
+        title: 'Khóa API',
+        placeholder: 'Vui lòng nhập khóa API',
+        lengthError: 'Độ dài khóa API không được nhỏ hơn 5 ký tự',
+      },
+    },
+    type: 'Loại',
+  },
+  about: {
+    changeLog: 'Nhật ký thay đổi',
+    updateNow: 'Cập nhật ngay',
+    nowAvailable: 'Dify {{version}} hiện đã có sẵn.',
+    latestAvailable: 'Dify {{version}} là phiên bản mới nhất hiện có.',
+  },
+  appMenus: {
+    overview: 'Tổng quan',
+    promptEng: 'Orchestrate',
+    apiAccess: 'Truy cập API',
+    logAndAnn: 'Nhật ký & Thông báo',
+  },
+  environment: {
+    testing: 'TESTING',
+    development: 'DEVELOPMENT',
+  },
+  appModes: {
+    completionApp: 'Ứng dụng Tạo văn bản',
+    chatApp: 'Ứng dụng Trò chuyện',
+  },
+  datasetMenus: {
+    documents: 'Tài liệu',
+    hitTesting: 'Kiểm tra truy vấn',
+    settings: 'Cài đặt',
+    emptyTip: 'Kiến thức chưa được liên kết, vui lòng đi đến ứng dụng hoặc plug-in để hoàn thành liên kết.',
+    viewDoc: 'Xem tài liệu',
+    relatedApp: 'các ứng dụng liên kết',
+  },
+  voiceInput: {
+    speaking: 'Hãy nói...',
+    converting: 'Chuyển đổi thành văn bản...',
+    notAllow: 'micro không được ủy quyền',
+  },
+  modelName: {
+    'gpt-3.5-turbo': 'GPT-3.5-Turbo',
+    'gpt-3.5-turbo-16k': 'GPT-3.5-Turbo-16K',
+    'gpt-4': 'GPT-4',
+    'gpt-4-32k': 'GPT-4-32K',
+    'text-davinci-003': 'Text-Davinci-003',
+    'text-embedding-ada-002': 'Text-Embedding-Ada-002',
+    'whisper-1': 'Whisper-1',
+    'claude-instant-1': 'Claude-Instant',
+    'claude-2': 'Claude-2',
+  },
+  chat: {
+    renameConversation: 'Đổi tên Cuộc trò chuyện',
+    conversationName: 'Tên cuộc trò chuyện',
+    conversationNamePlaceholder: 'Vui lòng nhập tên cuộc trò chuyện',
+    conversationNameCanNotEmpty: 'Yêu cầu nhập tên cuộc trò chuyện',
+    citation: {
+      title: 'THAM KHẢO',
+      linkToDataset: 'Liên kết tới Kiến thức',
+      characters: 'Ký tự:',
+      hitCount: 'Số lượt truy xuất:',
+      vectorHash: 'Vector hash:',
+      hitScore: 'Điểm truy xuất:',
+    },
+  },
+  promptEditor: {
+    placeholder: 'Viết từ khóa của bạn ở đây, nhập \'{\' để chèn một biến, nhập \'/\' để chèn một khối nội dung nhắc nhở',
+    context: {
+      item: {
+        title: 'Bối cảnh',
+        desc: 'Chèn mẫu bối cảnh',
+      },
+      modal: {
+        title: '{{num}} Kiến thức trong Bối cảnh',
+        add: 'Thêm Bối cảnh',
+        footer: 'Bạn có thể quản lý các bối cảnh trong phần Bối cảnh bên dưới.',
+      },
+    },
+    history: {
+      item: {
+        title: 'Lịch sử Cuộc trò chuyện',
+        desc: 'Chèn mẫu tin nhắn lịch sử',
+      },
+      modal: {
+        title: 'VÍ DỤ',
+        user: 'Xin chào',
+        assistant: 'Xin chào! Tôi có thể giúp gì cho bạn hôm nay?',
+        edit: 'Chỉnh sửa Tên Vai trò Cuộc trò chuyện',
+      },
+    },
+    variable: {
+      item: {
+        title: 'Biến & Công cụ Bên ngoài',
+        desc: 'Chèn Biến & Công cụ Bên ngoài',
+      },
+      modal: {
+        add: 'Biến mới',
+        addTool: 'Công cụ mới',
+      },
+    },
+    query: {
+      item: {
+        title: 'Truy vấn',
+        desc: 'Chèn mẫu truy vấn người dùng',
+      },
+    },
+    existed: 'Đã tồn tại trong tin nhắn',
+  },
+  imageUploader: {
+    uploadFromComputer: 'Tải lên từ Máy tính',
+    uploadFromComputerReadError: 'Đọc ảnh thất bại, vui lòng thử lại.',
+    uploadFromComputerUploadError: 'Tải ảnh lên thất bại, vui lòng tải lên lại.',
+    uploadFromComputerLimit: 'Ảnh tải lên không được vượt quá {{size}} MB',
+    pasteImageLink: 'Dán liên kết ảnh',
+    pasteImageLinkInputPlaceholder: 'Dán liên kết ảnh ở đây',
+    pasteImageLinkInvalid: 'Liên kết ảnh không hợp lệ',
+    imageUpload: 'Tải ảnh lên',
+  },
+}
+
+export default translation
diff --git a/web/i18n/vi-VN/custom.ts b/web/i18n/vi-VN/custom.ts
new file mode 100644
index 0000000000..b36476b62a
--- /dev/null
+++ b/web/i18n/vi-VN/custom.ts
@@ -0,0 +1,30 @@
+const translation = {
+  custom: 'Tùy chỉnh',
+  upgradeTip: {
+    prefix: 'Nâng cấp gói của bạn để',
+    suffix: 'tùy chỉnh thương hiệu của bạn.',
+  },
+  webapp: {
+    title: 'Tùy chỉnh thương hiệu WebApp',
+    removeBrand: 'Xóa Được hỗ trợ bởi Dify',
+    changeLogo: 'Thay đổi Hình ảnh Thương hiệu Được hỗ trợ bởi',
+    changeLogoTip: 'Định dạng SVG hoặc PNG với kích thước tối thiểu là 40x40px',
+  },
+  app: {
+    title: 'Tùy chỉnh thương hiệu tiêu đề ứng dụng',
+    changeLogoTip: 'Định dạng SVG hoặc PNG với kích thước tối thiểu là 80x80px',
+  },
+  upload: 'Tải lên',
+  uploading: 'Đang tải lên',
+  uploadedFail: 'Tải ảnh lên thất bại, vui lòng tải lên lại.',
+  change: 'Thay đổi',
+  apply: 'Áp dụng',
+  restore: 'Khôi phục Mặc định',
+  customize: {
+    contactUs: ' liên hệ với chúng tôi ',
+    prefix: 'Để tùy chỉnh logo thương hiệu trong ứng dụng, vui lòng',
+    suffix: 'để nâng cấp lên phiên bản Doanh nghiệp.',
+  },
+}
+
+export default translation
diff --git a/web/i18n/vi-VN/dataset-creation.ts b/web/i18n/vi-VN/dataset-creation.ts
new file mode 100644
index 0000000000..b29cd17c0b
--- /dev/null
+++ b/web/i18n/vi-VN/dataset-creation.ts
@@ -0,0 +1,130 @@
+const translation = {
+  steps: {
+    header: {
+      creation: 'Tạo Kiến thức',
+      update: 'Thêm dữ liệu',
+    },
+    one: 'Chọn nguồn dữ liệu',
+    two: 'Tiền xử lý và làm sạch văn bản',
+    three: 'Thực hiện và hoàn thành',
+  },
+  error: {
+    unavailable: 'Kiến thức này không khả dụng',
+  },
+  stepOne: {
+    filePreview: 'Xem trước tệp',
+    pagePreview: 'Xem trước trang',
+    dataSourceType: {
+      file: 'Nhập từ tệp văn bản',
+      notion: 'Đồng bộ từ Notion',
+      web: 'Đồng bộ từ trang web',
+    },
+    uploader: {
+      title: 'Tải lên tệp văn bản',
+      button: 'Kéo và thả tệp, hoặc',
+      browse: 'Duyệt',
+      tip: 'Hỗ trợ {{supportTypes}}. Tối đa {{size}}MB mỗi tệp.',
+      validation: {
+        typeError: 'Loại tệp không được hỗ trợ',
+        size: 'Tệp quá lớn. Tối đa là {{size}}MB',
+        count: 'Không hỗ trợ tải lên nhiều tệp',
+        filesNumber: 'Bạn đã đạt đến giới hạn tải lên lô của {{filesNumber}} tệp.',
+      },
+      cancel: 'Hủy bỏ',
+      change: 'Thay đổi',
+      failed: 'Tải lên thất bại',
+    },
+    notionSyncTitle: 'Notion chưa được kết nối',
+    notionSyncTip: 'Để đồng bộ với Notion, trước tiên cần thiết lập kết nối với Notion.',
+    connect: 'Đi đến kết nối',
+    button: 'tiếp theo',
+    emptyDatasetCreation: 'Tôi muốn tạo Kiến thức trống',
+    modal: {
+      title: 'Tạo Kiến thức trống',
+      tip: 'Một Kiến thức trống sẽ không chứa tài liệu nào, và bạn có thể tải lên tài liệu bất kỳ lúc nào.',
+      input: 'Tên Kiến thức',
+      placeholder: 'Vui lòng nhập',
+      nameNotEmpty: 'Tên không thể để trống',
+      nameLengthInvaild: 'Tên phải từ 1 đến 40 ký tự',
+      cancelButton: 'Hủy bỏ',
+      confirmButton: 'Tạo',
+      failed: 'Tạo thất bại',
+    },
+  },
+  stepTwo: {
+    segmentation: 'Cài đặt đoạn',
+    auto: 'Tự động',
+    autoDescription: 'Tự động thiết lập quy tắc đoạn và tiền xử lý. Người dùng không quen thuộc được khuyến nghị chọn điều này.',
+    custom: 'Tùy chỉnh',
+    customDescription: 'Tùy chỉnh quy tắc đoạn, độ dài đoạn và quy tắc tiền xử lý, v.v.',
+    separator: 'Bộ phận xác định đoạn',
+    separatorPlaceholder: 'Ví dụ, dòng mới (\\\\n) hoặc bộ phận phân cách đặc biệt (như "***")',
+    maxLength: 'Độ dài tối đa của đoạn',
+    overlap: 'Chồng lấn đoạn',
+    overlapTip: 'Thiết lập chồng lấn đoạn có thể duy trì sự liên quan ngữ nghĩa giữa chúng, tăng cường hiệu ứng truy xuất. Đề xuất thiết lập từ 10% đến 25% của kích thước đoạn tối đa.',
+    overlapCheck: 'Chồng lấn đoạn không nên lớn hơn độ dài tối đa của đoạn',
+    rules: 'Quy tắc tiền xử lý văn bản',
+    removeExtraSpaces: 'Thay thế khoảng trắng liên tục, dòng mới và tab',
+    removeUrlEmails: 'Xóa tất cả URL và địa chỉ email',
+    removeStopwords: 'Loại bỏ các từ dừng như "một", "một", "những"',
+    preview: 'Xác nhận & Xem trước',
+    reset: 'Đặt lại',
+    indexMode: 'Chế độ chỉ số',
+    qualified: 'Chất lượng cao',
+    recommend: 'Khuyến nghị',
+    qualifiedTip: 'Gọi giao diện nhúng hệ thống mặc định để xử lý để cung cấp độ chính xác cao hơn khi người dùng truy vấn.',
+    warning: 'Vui lòng thiết lập khóa API nhà cung cấp mô hình trước.',
+    click: 'Đi đến cài đặt',
+    economical: 'Tiết kiệm',
+    economicalTip: 'Sử dụng các động cơ vector ngoại tuyến, chỉ số từ khóa, v.v. để giảm chính xác mà không tốn token',
+    QATitle: 'Phân đoạn theo định dạng Câu hỏi & Trả lời',
+    QATip: 'Bật tùy chọn này sẽ tiêu tốn thêm token',
+    QALanguage: 'Phân đoạn bằng',
+    emstimateCost: 'Ước lượng',
+    emstimateSegment: 'Đoạn ước tính',
+    segmentCount: 'đoạn',
+    calculating: 'Đang tính toán...',
+    fileSource: 'Tiền xử lý tài liệu',
+    notionSource: 'Tiền xử lý trang',
+    other: 'và những ',
+    fileUnit: ' tệp',
+    notionUnit: ' trang',
+    previousStep: 'Bước trước',
+    nextStep: 'Lưu & Xử lý',
+    save: 'Lưu & Xử lý',
+    cancel: 'Hủy bỏ',
+    sideTipTitle: 'Tại sao phải phân đoạn và tiền xử lý?',
+    sideTipP1: 'Khi xử lý dữ liệu văn bản, phân đoạn và làm sạch là hai bước tiền xử lý quan trọng.',
+    sideTipP2: 'Phân đoạn chia nhỏ văn bản dài thành đoạn để mô hình hiểu được tốt hơn. Điều này cải thiện chất lượng và tính liên quan của kết quả mô hình.',
+    sideTipP3: 'Làm sạch loại bỏ các ký tự và định dạng không cần thiết, làm cho Kiến thức trở nên sạch sẽ và dễ dàng phân tích hơn.',
+    sideTipP4: 'Phân đoạn và làm sạch đúng cách cải thiện hiệu suất của mô hình, cung cấp kết quả chính xác và có giá trị hơn.',
+    previewTitle: 'Xem trước',
+    previewTitleButton: 'Xem trước',
+    previewButton: 'Chuyển sang dạng Câu hỏi & Trả lời',
+    previewSwitchTipStart: 'Xem trước đoạn hiện tại đang ở định dạng văn bản, chuyển sang xem trước dạng câu hỏi và trả lời sẽ',
+    previewSwitchTipEnd: ' tiêu tốn thêm token',
+    characters: 'ký tự',
+    indexSettedTip: 'Để thay đổi phương pháp chỉ số, vui lòng đi tới ',
+    retrivalSettedTip: 'Để thay đổi phương pháp chỉ số, vui lòng đi tới ',
+    datasetSettingLink: 'cài đặt Kiến thức.',
+  },
+  stepThree: {
+    creationTitle: '🎉 Kiến thức đã được tạo',
+    creationContent: 'Chúng tôi đã tự động đặt tên cho Kiến thức, bạn có thể sửa đổi nó bất kỳ lúc nào',
+    label: 'Tên Kiến thức',
+    additionTitle: '🎉 Tài liệu đã được tải lên',
+    additionP1: 'Tài liệu đã được tải lên Kiến thức',
+    additionP2: ', bạn có thể tìm thấy nó trong danh sách tài liệu của Kiến thức.',
+    stop: 'Dừng xử lý',
+    resume: 'Tiếp tục xử lý',
+    navTo: 'Đi đến tài liệu',
+    sideTipTitle: 'Tiếp theo là gì',
+    sideTipContent: 'Sau khi tài liệu hoàn thành chỉ mục, Kiến thức có thể được tích hợp vào ứng dụng như một ngữ cảnh, bạn có thể tìm cài đặt ngữ cảnh trong trang chỉ đạo đoạn. Bạn cũng có thể tạo nó như một plugin chỉ mục ChatGPT độc lập để phát hành.',
+    modelTitle: 'Bạn có chắc chắn muốn dừng việc nhúng?',
+    modelContent: 'Nếu bạn cần tiếp tục xử lý sau này, bạn sẽ tiếp tục từ vị trí bạn đã dừng lại.',
+    modelButtonConfirm: 'Xác nhận',
+    modelButtonCancel: 'Hủy bỏ',
+  },
+}
+
+export default translation
diff --git a/web/i18n/vi-VN/dataset-documents.ts b/web/i18n/vi-VN/dataset-documents.ts
new file mode 100644
index 0000000000..3cf486a3c9
--- /dev/null
+++ b/web/i18n/vi-VN/dataset-documents.ts
@@ -0,0 +1,349 @@
+const translation = {
+  list: {
+    title: 'Tài liệu',
+    desc: 'Tất cả các tệp của Kiến thức được hiển thị ở đây, và toàn bộ Kiến thức có thể được liên kết với trích dẫn của Dify hoặc được lập chỉ mục thông qua plugin Chat.',
+    addFile: 'Thêm tệp',
+    addPages: 'Thêm Trang',
+    table: {
+      header: {
+        fileName: 'TÊN TỆP',
+        words: 'TỪ',
+        hitCount: 'SỐ LẦN TRUY VẤN',
+        uploadTime: 'THỜI GIAN TẢI LÊN',
+        status: 'TRẠNG THÁI',
+        action: 'HÀNH ĐỘNG',
+      },
+    },
+    action: {
+      uploadFile: 'Tải lên tệp mới',
+      settings: 'Cài đặt phân đoạn',
+      addButton: 'Thêm đoạn',
+      add: 'Thêm một đoạn',
+      batchAdd: 'Thêm hàng loạt',
+      archive: 'Lưu trữ',
+      unarchive: 'Khôi phục',
+      delete: 'Xóa',
+      enableWarning: 'Tệp được lưu trữ không thể được kích hoạt',
+      sync: 'Đồng bộ',
+    },
+    index: {
+      enable: 'Kích hoạt',
+      disable: 'Vô hiệu hóa',
+      all: 'Tất cả',
+      enableTip: 'Tệp có thể được lập chỉ mục',
+      disableTip: 'Tệp không thể được lập chỉ mục',
+    },
+    status: {
+      queuing: 'Đang chờ',
+      indexing: 'Đang lập chỉ mục',
+      paused: 'Tạm dừng',
+      error: 'Lỗi',
+      available: 'Có sẵn',
+      enabled: 'Đã kích hoạt',
+      disabled: 'Đã vô hiệu hóa',
+      archived: 'Đã lưu trữ',
+    },
+    empty: {
+      title: 'Chưa có tài liệu',
+      upload: {
+        tip: 'Bạn có thể tải lên tệp, đồng bộ từ trang web, hoặc từ ứng dụng web như Notion, GitHub, v.v.',
+      },
+      sync: {
+        tip: 'Dify sẽ định kỳ tải xuống tệp từ Notion của bạn và hoàn tất xử lý.',
+      },
+    },
+    delete: {
+      title: 'Bạn có chắc chắn muốn xóa?',
+      content: 'Nếu bạn cần tiếp tục xử lý sau này, bạn sẽ tiếp tục từ vị trí bạn đã dừng lại',
+    },
+    batchModal: {
+      title: 'Thêm đoạn hàng loạt',
+      csvUploadTitle: 'Kéo và thả tệp CSV của bạn vào đây, hoặc ',
+      browse: 'duyệt',
+      tip: 'Tệp CSV phải tuân thủ cấu trúc sau:',
+      question: 'câu hỏi',
+      answer: 'trả lời',
+      contentTitle: 'nội dung đoạn',
+      content: 'nội dung',
+      template: 'Tải mẫu ở đây',
+      cancel: 'Hủy bỏ',
+      run: 'Chạy Hàng loạt',
+      runError: 'Chạy hàng loạt thất bại',
+      processing: 'Đang xử lý hàng loạt',
+      completed: 'Nhập đã hoàn thành',
+      error: 'Lỗi nhập',
+      ok: 'OK',
+    },
+  },
+  metadata: {
+    title: 'Siêu dữ liệu',
+    desc: 'Gắn nhãn siêu dữ liệu cho các tài liệu cho phép trí tuệ nhân tạo truy cập chúng một cách kịp thời và tiết lộ nguồn của các tài liệu tham chiếu cho người dùng.',
+    dateTimeFormat: 'D MMMM, YYYY hh:mm A',
+    docTypeSelectTitle: 'Vui lòng chọn loại tài liệu',
+    docTypeChangeTitle: 'Thay đổi loại tài liệu',
+    docTypeSelectWarning:
+      'Nếu thay đổi loại tài liệu, các siêu dữ liệu hiện tại sẽ không được bảo toàn nữa',
+    firstMetaAction: 'Bắt đầu',
+    placeholder: {
+      add: 'Thêm ',
+      select: 'Chọn ',
+    },
+    source: {
+      upload_file: 'Tải lên Tệp',
+      notion: 'Đồng bộ từ Notion',
+      github: 'Đồng bộ từ Github',
+    },
+    type: {
+      book: 'Sách',
+      webPage: 'Trang Web',
+      paper: 'Bài báo',
+      socialMediaPost: 'Bài viết trên Mạng xã hội',
+      personalDocument: 'Tài liệu cá nhân',
+      businessDocument: 'Tài liệu doanh nghiệp',
+      IMChat: 'Trò chuyện qua tin nhắn',
+      wikipediaEntry: 'Bài viết Wikipedia',
+      notion: 'Đồng bộ từ Notion',
+      github: 'Đồng bộ từ Github',
+      technicalParameters: 'Tham số kỹ thuật',
+    },
+    field: {
+      processRule: {
+        processDoc: 'Xử lý Tài liệu',
+        segmentRule: 'Quy tắc phân đoạn',
+        segmentLength: 'Chiều dài các đoạn',
+        processClean: 'Quy tắc làm sạch Văn bản',
+      },
+      book: {
+        title: 'Tiêu đề',
+        language: 'Ngôn ngữ',
+        author: 'Tác giả',
+        publisher: 'Nhà xuất bản',
+        publicationDate: 'Ngày xuất bản',
+        ISBN: 'ISBN',
+        category: 'Danh mục',
+      },
+      webPage: {
+        title: 'Tiêu đề',
+        url: 'URL',
+        language: 'Ngôn ngữ',
+        authorPublisher: 'Tác giả/Nhà xuất bản',
+        publishDate: 'Ngày xuất bản',
+        topicsKeywords: 'Chủ đề/Từ khóa',
+        description: 'Mô tả',
+      },
+      paper: {
+        title: 'Tiêu đề',
+        language: 'Ngôn ngữ',
+        author: 'Tác giả',
+        publishDate: 'Ngày xuất bản',
+        journalConferenceName: 'Tên tạp chí/Hội nghị',
+        volumeIssuePage: 'Số/Trang',
+        DOI: 'DOI',
+        topicsKeywords: 'Chủ đề/Từ khóa',
+        abstract: 'Tóm tắt',
+      },
+      socialMediaPost: {
+        platform: 'Nền tảng',
+        authorUsername: 'Tác giả/Tên người dùng',
+        publishDate: 'Ngày đăng',
+        postURL: 'URL Bài viết',
+        topicsTags: 'Chủ đề/Thẻ',
+      },
+      personalDocument: {
+        title: 'Tiêu đề',
+        author: 'Tác giả',
+        creationDate: 'Ngày tạo',
+        lastModifiedDate: 'Ngày sửa đổi cuối cùng',
+        documentType: 'Loại tài liệu',
+        tagsCategory: 'Thẻ/Danh mục',
+      },
+      businessDocument: {
+        title: 'Tiêu đề',
+        author: 'Tác giả',
+        creationDate: 'Ngày tạo',
+        lastModifiedDate: 'Ngày sửa đổi cuối cùng',
+        documentType: 'Loại tài liệu',
+        departmentTeam: 'Phòng ban/Nhóm',
+      },
+      IMChat: {
+        chatPlatform: 'Nền tảng Trò chuyện',
+        chatPartiesGroupName: 'Đối tác Trò chuyện/Tên nhóm',
+        participants: 'Tham gia viên',
+        startDate: 'Ngày bắt đầu',
+        endDate: 'Ngày kết thúc',
+        topicsKeywords: 'Chủ đề/Từ khóa',
+        fileType: 'Loại tệp',
+      },
+      wikipediaEntry: {
+        title: 'Tiêu đề',
+        language: 'Ngôn ngữ',
+        webpageURL: 'URL trang web',
+        editorContributor: 'Biên tập viên/Đóng góp viên',
+        lastEditDate: 'Ngày chỉnh sửa cuối cùng',
+        summaryIntroduction: 'Tóm tắt/Giới thiệu',
+      },
+      notion: {
+        title: 'Tiêu đề',
+        language: 'Ngôn ngữ',
+        author: 'Tác giả',
+        createdTime: 'Thời gian tạo',
+        lastModifiedTime: 'Thời gian chỉnh sửa cuối cùng',
+        url: 'URL',
+        tag: 'Thẻ',
+        description: 'Mô tả',
+      },
+      github: {
+        repoName: 'Tên kho lưu trữ',
+        repoDesc: 'Mô tả kho lưu trữ',
+        repoOwner: 'Chủ sở hữu kho lưu trữ',
+        fileName: 'Tên tệp',
+        filePath: 'Đường dẫn tệp',
+        programmingLang: 'Ngôn ngữ lập trình',
+        url: 'URL',
+        license: 'Giấy phép',
+        lastCommitTime: 'Thời gian commit cuối cùng',
+        lastCommitAuthor: 'Tác giả commit cuối cùng',
+      },
+      originInfo: {
+        originalFilename: 'Tên tệp gốc',
+        originalFileSize: 'Kích thước tệp gốc',
+        uploadDate: 'Ngày tải lên',
+        lastUpdateDate: 'Ngày cập nhật cuối cùng',
+        source: 'Nguồn',
+      },
+      technicalParameters: {
+        segmentSpecification: 'Đặc tả các đoạn',
+        segmentLength: 'Chiều dài các đoạn',
+        avgParagraphLength: 'Độ dài trung bình của đoạn',
+        paragraphs: 'Các đoạn',
+        hitCount: 'Số lần truy vấn',
+        embeddingTime: 'Thời gian nhúng',
+        embeddedSpend: 'Chi phí nhúng',
+      },
+    },
+    languageMap: {
+      zh: 'Tiếng Trung',
+      en: 'Tiếng Anh',
+      es: 'Tiếng Tây Ban Nha',
+      fr: 'Tiếng Pháp',
+      de: 'Tiếng Đức',
+      ja: 'Tiếng Nhật',
+      ko: 'Tiếng Hàn',
+      ru: 'Tiếng Nga',
+      ar: 'Tiếng Ả Rập',
+      pt: 'Tiếng Bồ Đào Nha',
+      it: 'Tiếng Ý',
+      nl: 'Tiếng Hà Lan',
+      pl: 'Tiếng Ba Lan',
+      sv: 'Tiếng Thụy Điển',
+      tr: 'Tiếng Thổ Nhĩ Kỳ',
+      he: 'Tiếng Do Thái',
+      hi: 'Tiếng Hindi',
+      da: 'Tiếng Đan Mạch',
+      fi: 'Tiếng Phần Lan',
+      no: 'Tiếng Na Uy',
+      hu: 'Tiếng Hungary',
+      el: 'Tiếng Hy Lạp',
+      cs: 'Tiếng Séc',
+      th: 'Tiếng Thái',
+      id: 'Tiếng Indonesia',
+    },
+    categoryMap: {
+      book: {
+        fiction: 'Hư cấu',
+        biography: 'Tiểu sử',
+        history: 'Lịch sử',
+        science: 'Khoa học',
+        technology: 'Công nghệ',
+        education: 'Giáo dục',
+        philosophy: 'Triết học',
+        religion: 'Tôn giáo',
+        socialSciences: 'Khoa học xã hội',
+        art: 'Nghệ thuật',
+        travel: 'Du lịch',
+        health: 'Sức khỏe',
+        selfHelp: 'Tự giúp bản thân',
+        businessEconomics: 'Kinh doanh và kinh tế',
+        cooking: 'Nấu ăn',
+        childrenYoungAdults: 'Trẻ em và thanh thiếu niên',
+        comicsGraphicNovels: 'Truyện tranh và tiểu thuyết đồ họa',
+        poetry: 'Thơ',
+        drama: 'Kịch',
+        other: 'Khác',
+      },
+      personalDoc: {
+        notes: 'Ghi chú',
+        blogDraft: 'Nháp Blog',
+        diary: 'Nhật ký',
+        researchReport: 'Báo cáo nghiên cứu',
+        bookExcerpt: 'Trích đoạn sách',
+        schedule: 'Lịch trình',
+        list: 'Danh sách',
+        projectOverview: 'Tổng quan dự án',
+        photoCollection: 'Bộ sưu tập ảnh',
+        creativeWriting: 'Viết sáng tạo',
+        codeSnippet: 'Đoạn mã',
+        designDraft: 'Bản dựng thiết kế',
+        personalResume: 'Sơ yếu lý lịch cá nhân',
+        other: 'Khác',
+      },
+      businessDoc: {
+        meetingMinutes: 'Biên bản cuộc họp',
+        researchReport: 'Báo cáo nghiên cứu',
+        proposal: 'Đề xuất',
+        employeeHandbook: 'Sổ tay nhân viên',
+        trainingMaterials: 'Tài liệu đào tạo',
+        requirementsDocument: 'Tài liệu yêu cầu',
+        designDocument: 'Tài liệu thiết kế',
+        productSpecification: 'Thông số sản phẩm',
+        financialReport: 'Báo cáo tài chính',
+        marketAnalysis: 'Phân tích thị trường',
+        projectPlan: 'Kế hoạch dự án',
+        teamStructure: 'Cấu trúc nhóm',
+        policiesProcedures: 'Chính sách và quy trình',
+        contractsAgreements: 'Hợp đồng và thoả thuận',
+        emailCorrespondence: 'Thư tín',
+        other: 'Khác',
+      },
+    },
+  },
+  embedding: {
+    processing: 'Đang nhúng...',
+    paused: 'Đã tạm dừng việc nhúng',
+    completed: 'Hoàn tất việc nhúng',
+    error: 'Lỗi khi nhúng',
+    docName: 'Đang xử lý văn bản',
+    mode: 'Quy tắc phân đoạn',
+    segmentLength: 'Chiều dài các đoạn',
+    textCleaning: 'Định nghĩa và làm sạch Văn bản',
+    segments: 'Đoạn',
+    highQuality: 'Chế độ chất lượng cao',
+    economy: 'Chế độ tiết kiệm',
+    estimate: 'Ước lượng tiêu thụ',
+    stop: 'Dừng xử lý',
+    resume: 'Tiếp tục xử lý',
+    automatic: 'Tự động',
+    custom: 'Tùy chỉnh',
+    previewTip: 'Xem trước đoạn sẽ có sẵn sau khi việc nhúng hoàn tất',
+  },
+  segment: {
+    paragraphs: 'Đoạn',
+    keywords: 'Từ khóa',
+    addKeyWord: 'Thêm từ khóa',
+    keywordError: 'Độ dài tối đa của từ khóa là 20',
+    characters: 'ký tự',
+    hitCount: 'Số lần truy vấn',
+    vectorHash: 'Băm vector: ',
+    questionPlaceholder: 'thêm câu hỏi ở đây',
+    questionEmpty: 'Câu hỏi không thể trống',
+    answerPlaceholder: 'thêm câu trả lời ở đây',
+    answerEmpty: 'Câu trả lời không thể trống',
+    contentPlaceholder: 'thêm nội dung ở đây',
+    contentEmpty: 'Nội dung không thể trống',
+    newTextSegment: 'Đoạn văn mới',
+    newQaSegment: 'Đoạn câu hỏi & trả lời mới',
+    delete: 'Xóa đoạn này?',
+  },
+}
+
+export default translation
diff --git a/web/i18n/vi-VN/dataset-hit-testing.ts b/web/i18n/vi-VN/dataset-hit-testing.ts
new file mode 100644
index 0000000000..044e34ab5f
--- /dev/null
+++ b/web/i18n/vi-VN/dataset-hit-testing.ts
@@ -0,0 +1,28 @@
+const translation = {
+  title: 'Kiểm Tra Truy Vấn',
+  desc: 'Kiểm tra hiệu ứng đánh trúng của Kiến thức dựa trên văn bản truy vấn đã cho.',
+  dateTimeFormat: 'MM/DD/YYYY hh:mm A',
+  recents: 'Gần đây',
+  table: {
+    header: {
+      source: 'Nguồn',
+      text: 'Văn bản',
+      time: 'Thời gian',
+    },
+  },
+  input: {
+    title: 'Văn bản nguồn',
+    placeholder: 'Vui lòng nhập một văn bản, một câu khẳng định ngắn được khuyến nghị.',
+    countWarning: 'Tối đa 200 ký tự.',
+    indexWarning: 'Chỉ có trong Kiến thức chất lượng cao.',
+    testing: 'Kiểm tra',
+  },
+  hit: {
+    title: 'RETRIEVAL PARAGRAPHS',
+    emptyTip: 'Kết quả Kiểm Tra Truy Vấn sẽ hiển thị ở đây',
+  },
+  noRecentTip: 'Không có kết quả truy vấn gần đây ở đây',
+  viewChart: 'Xem VECTOR CHART',
+}
+
+export default translation
diff --git a/web/i18n/vi-VN/dataset-settings.ts b/web/i18n/vi-VN/dataset-settings.ts
new file mode 100644
index 0000000000..87cf320bfa
--- /dev/null
+++ b/web/i18n/vi-VN/dataset-settings.ts
@@ -0,0 +1,33 @@
+const translation = {
+  title: 'Cài đặt Kiến thức',
+  desc: 'Ở đây, bạn có thể sửa đổi các thuộc tính và phương pháp làm việc của Kiến thức.',
+  form: {
+    name: 'Tên Kiến thức',
+    namePlaceholder: 'Vui lòng nhập tên Kiến thức',
+    nameError: 'Tên không thể trống',
+    desc: 'Mô tả Kiến thức',
+    descInfo: 'Vui lòng viết mô tả văn bản rõ ràng để chỉ rõ nội dung của Kiến thức. Mô tả này sẽ được sử dụng làm cơ sở cho việc kết hợp khi lựa chọn từ nhiều Kiến thức cho sự suy luận.',
+    descPlaceholder: 'Miêu tả những gì có trong Kiến thức này. Một mô tả chi tiết cho phép AI truy cập nội dung của Kiến thức một cách kịp thời. Nếu trống, Dify sẽ sử dụng chiến lược hit mặc định.',
+    descWrite: 'Tìm hiểu cách viết mô tả Kiến thức tốt.',
+    permissions: 'Quyền hạn',
+    permissionsOnlyMe: 'Chỉ mình tôi',
+    permissionsAllMember: 'Tất cả thành viên nhóm',
+    indexMethod: 'Phương pháp chỉ mục',
+    indexMethodHighQuality: 'Chất lượng cao',
+    indexMethodHighQualityTip: 'Gọi giao diện nhúng của OpenAI để xử lý để cung cấp độ chính xác cao hơn khi người dùng truy vấn.',
+    indexMethodEconomy: 'Tiết kiệm',
+    indexMethodEconomyTip: 'Sử dụng các công cụ nhúng vector ngoại tuyến, chỉ mục từ khóa, v.v. để giảm độ chính xác mà không cần chi tiêu token',
+    embeddingModel: 'Mô hình nhúng',
+    embeddingModelTip: 'Để thay đổi mô hình nhúng, vui lòng đi tới ',
+    embeddingModelTipLink: 'Cài đặt',
+    retrievalSetting: {
+      title: 'Cài đặt truy vấn',
+      learnMore: 'Tìm hiểu thêm',
+      description: ' về phương pháp truy vấn.',
+      longDescription: ' về phương pháp truy vấn, bạn có thể thay đổi điều này bất kỳ lúc nào trong cài đặt Kiến thức.',
+    },
+    save: 'Lưu',
+  },
+}
+
+export default translation
diff --git a/web/i18n/vi-VN/dataset.ts b/web/i18n/vi-VN/dataset.ts
new file mode 100644
index 0000000000..90b46a6c2b
--- /dev/null
+++ b/web/i18n/vi-VN/dataset.ts
@@ -0,0 +1,47 @@
+const translation = {
+  knowledge: 'Kiến thức',
+  documentCount: ' tài liệu',
+  wordCount: 'k từ',
+  appCount: ' ứng dụng liên kết',
+  createDataset: 'Tạo Kiến thức',
+  createDatasetIntro: 'Nhập dữ liệu văn bản của bạn hoặc viết dữ liệu theo thời gian thực qua Webhook để tăng cường ngữ cảnh LLM.',
+  deleteDatasetConfirmTitle: 'Xóa Kiến thức này?',
+  deleteDatasetConfirmContent:
+    'Xóa Kiến thức là không thể đảo ngược. Người dùng sẽ không còn có khả năng truy cập Kiến thức của bạn nữa, và tất cả các cấu hình và nhật ký nhắc nhở sẽ bị xóa vĩnh viễn.',
+  datasetDeleted: 'Kiến thức đã bị xóa',
+  datasetDeleteFailed: 'Xóa Kiến thức không thành công',
+  didYouKnow: 'Bạn đã biết chưa?',
+  intro1: 'Kiến thức có thể được tích hợp vào ứng dụng Dify ',
+  intro2: 'như một ngữ cảnh',
+  intro3: ',',
+  intro4: 'hoặc nó ',
+  intro5: 'có thể được tạo',
+  intro6: ' dưới dạng một phần cắm chỉ mục ChatGPT độc lập để xuất bản',
+  unavailable: 'Không khả dụng',
+  unavailableTip: 'Mô hình nhúng không khả dụng, mô hình nhúng mặc định cần được cấu hình',
+  datasets: 'KIẾN THỨC',
+  datasetsApi: 'API',
+  retrieval: {
+    semantic_search: {
+      title: 'Tìm kiếm Vector',
+      description: 'Tạo các nhúng truy vấn và tìm kiếm phần văn bản giống nhất với biểu diễn vector của nó.',
+    },
+    full_text_search: {
+      title: 'Tìm kiếm Toàn văn bản',
+      description: 'Chỉ mục tất cả các thuật ngữ trong tài liệu, cho phép người dùng tìm kiếm bất kỳ thuật ngữ nào và truy xuất phần văn bản liên quan chứa các thuật ngữ đó.',
+    },
+    hybrid_search: {
+      title: 'Tìm kiếm Hybrid',
+      description: 'Thực hiện tìm kiếm toàn văn bản và tìm kiếm vector đồng thời, sắp xếp lại để chọn lựa phù hợp nhất với truy vấn của người dùng. Cấu hình của API mô hình Rerank là cần thiết.',
+      recommend: 'Gợi ý',
+    },
+    invertedIndex: {
+      title: 'Chỉ mục Nghịch đảo',
+      description: 'Chỉ mục Nghịch đảo là một cấu trúc được sử dụng cho việc truy xuất hiệu quả. Tổ chức theo thuật ngữ, mỗi thuật ngữ trỏ đến tài liệu hoặc trang web chứa nó.',
+    },
+    change: 'Thay đổi',
+    changeRetrievalMethod: 'Thay đổi phương pháp truy xuất',
+  },
+}
+
+export default translation
diff --git a/web/i18n/vi-VN/explore.ts b/web/i18n/vi-VN/explore.ts
new file mode 100644
index 0000000000..eb71385f9f
--- /dev/null
+++ b/web/i18n/vi-VN/explore.ts
@@ -0,0 +1,41 @@
+const translation = {
+  title: 'Khám phá',
+  sidebar: {
+    discovery: 'Khám phá',
+    chat: 'Chat',
+    workspace: 'Kho lưu trữ',
+    action: {
+      pin: 'Ghim',
+      unpin: 'Bỏ ghim',
+      rename: 'Đổi tên',
+      delete: 'Xóa',
+    },
+    delete: {
+      title: 'Xóa ứng dụng',
+      content: 'Bạn có chắc chắn muốn xóa ứng dụng này không?',
+    },
+  },
+  apps: {
+    title: 'Khám phá Ứng dụng bởi Dify',
+    description: 'Sử dụng ngay các ứng dụng mẫu này hoặc tùy chỉnh các ứng dụng của bạn dựa trên các mẫu.',
+    allCategories: 'Tất cả Danh mục',
+  },
+  appCard: {
+    addToWorkspace: 'Thêm vào Kho lưu trữ',
+    customize: 'Tùy chỉnh',
+  },
+  appCustomize: {
+    title: 'Tạo ứng dụng từ {{name}}',
+    subTitle: 'Biểu tượng và tên ứng dụng',
+    nameRequired: 'Tên ứng dụng là bắt buộc',
+  },
+  category: {
+    Assistant: 'Trợ lý',
+    Writing: 'Viết',
+    Translate: 'Dịch',
+    Programming: 'Lập trình',
+    HR: 'Nhân sự',
+  },
+}
+
+export default translation
diff --git a/web/i18n/vi-VN/layout.ts b/web/i18n/vi-VN/layout.ts
new file mode 100644
index 0000000000..928649474b
--- /dev/null
+++ b/web/i18n/vi-VN/layout.ts
@@ -0,0 +1,4 @@
+const translation = {
+}
+
+export default translation
diff --git a/web/i18n/vi-VN/login.ts b/web/i18n/vi-VN/login.ts
new file mode 100644
index 0000000000..9f60027aa9
--- /dev/null
+++ b/web/i18n/vi-VN/login.ts
@@ -0,0 +1,59 @@
+const translation = {
+  pageTitle: 'Xin chào, hãy bắt đầu!👋',
+  welcome: 'Chào mừng bạn đến với Dify, vui lòng đăng nhập để tiếp tục.',
+  email: 'Địa chỉ Email',
+  emailPlaceholder: 'Email của bạn',
+  password: 'Mật khẩu',
+  passwordPlaceholder: 'Mật khẩu của bạn',
+  name: 'Tên người dùng',
+  namePlaceholder: 'Tên người dùng của bạn',
+  forget: 'Quên mật khẩu?',
+  signBtn: 'Đăng nhập',
+  installBtn: 'Cài đặt',
+  setAdminAccount: 'Thiết lập tài khoản quản trị',
+  setAdminAccountDesc: 'Quyền tối đa cho tài khoản quản trị, có thể được sử dụng để tạo ứng dụng và quản lý các nhà cung cấp LLM, v.v.',
+  createAndSignIn: 'Tạo và đăng nhập',
+  oneMoreStep: 'Một bước nữa',
+  createSample: 'Dựa trên thông tin này, chúng tôi sẽ tạo một ứng dụng mẫu cho bạn',
+  invitationCode: 'Mã mời',
+  invitationCodePlaceholder: 'Mã mời của bạn',
+  interfaceLanguage: 'Ngôn ngữ giao diện',
+  timezone: 'Múi giờ',
+  go: 'Đi đến Dify',
+  sendUsMail: 'Gửi email giới thiệu của bạn cho chúng tôi, và chúng tôi sẽ xử lý yêu cầu mời.',
+  acceptPP: 'Tôi đã đọc và chấp nhận chính sách bảo mật',
+  reset: 'Vui lòng chạy lệnh sau để đặt lại mật khẩu của bạn',
+  withGitHub: 'Tiếp tục với GitHub',
+  withGoogle: 'Tiếp tục với Google',
+  rightTitle: 'Mở khóa tiềm năng đầy đủ của LLM',
+  rightDesc: 'Dễ dàng xây dựng ứng dụng AI hấp dẫn mắt, có thể vận hành và cải thiện được.',
+  tos: 'Điều khoản dịch vụ',
+  pp: 'Chính sách bảo mật',
+  tosDesc: 'Bằng cách đăng ký, bạn đồng ý với',
+  donthave: 'Chưa có?',
+  invalidInvitationCode: 'Mã mời không hợp lệ',
+  accountAlreadyInited: 'Tài khoản đã được khởi tạo',
+  error: {
+    emailEmpty: 'Địa chỉ Email là bắt buộc',
+    emailInValid: 'Vui lòng nhập một địa chỉ email hợp lệ',
+    nameEmpty: 'Tên là bắt buộc',
+    passwordEmpty: 'Mật khẩu là bắt buộc',
+    passwordInvalid: 'Mật khẩu phải chứa chữ và số, và độ dài phải lớn hơn 8',
+  },
+  license: {
+    tip: 'Trước khi bắt đầu Phiên bản Cộng đồng của Dify, hãy đọc',
+    link: 'Giấy phép mã nguồn mở trên GitHub',
+  },
+  join: 'Tham gia',
+  joinTipStart: 'Mời bạn tham gia',
+  joinTipEnd: 'đội tại Dify',
+  invalid: 'Liên kết đã hết hạn',
+  explore: 'Khám phá Dify',
+  activatedTipStart: 'Bạn đã tham gia',
+  activatedTipEnd: 'đội',
+  activated: 'Đăng nhập ngay',
+  adminInitPassword: 'Mật khẩu khởi tạo quản trị viên',
+  validate: 'Xác thực',
+}
+
+export default translation
diff --git a/web/i18n/vi-VN/register.ts b/web/i18n/vi-VN/register.ts
new file mode 100644
index 0000000000..928649474b
--- /dev/null
+++ b/web/i18n/vi-VN/register.ts
@@ -0,0 +1,4 @@
+const translation = {
+}
+
+export default translation
diff --git a/web/i18n/vi-VN/share-app.ts b/web/i18n/vi-VN/share-app.ts
new file mode 100644
index 0000000000..14fc638fb3
--- /dev/null
+++ b/web/i18n/vi-VN/share-app.ts
@@ -0,0 +1,74 @@
+const translation = {
+  common: {
+    welcome: 'Chào mừng đến với',
+    appUnavailable: 'Ứng dụng không khả dụng',
+    appUnkonwError: 'Ứng dụng không khả dụng',
+  },
+  chat: {
+    newChat: 'Trò chuyện mới',
+    pinnedTitle: 'Đã ghim',
+    unpinnedTitle: 'Trò chuyện',
+    newChatDefaultName: 'Cuộc trò chuyện mới',
+    resetChat: 'Đặt lại cuộc trò chuyện',
+    powerBy: 'Được cung cấp bởi',
+    prompt: 'Lời nhắc',
+    privatePromptConfigTitle: 'Cài đặt cuộc trò chuyện',
+    publicPromptConfigTitle: 'Lời nhắc ban đầu',
+    configStatusDes: 'Trước khi bắt đầu, bạn có thể chỉnh sửa cài đặt cuộc trò chuyện',
+    configDisabled:
+      'Cài đặt của phiên trước đã được sử dụng cho phiên này.',
+    startChat: 'Bắt đầu trò chuyện',
+    privacyPolicyLeft:
+      'Vui lòng đọc ',
+    privacyPolicyMiddle:
+      'chính sách bảo mật',
+    privacyPolicyRight:
+      ' được cung cấp bởi nhà phát triển ứng dụng.',
+    deleteConversation: {
+      title: 'Xóa cuộc trò chuyện',
+      content: 'Bạn có chắc muốn xóa cuộc trò chuyện này không?',
+    },
+    tryToSolve: 'Thử giải quyết',
+    temporarySystemIssue: 'Xin lỗi, có sự cố tạm thời của hệ thống.',
+  },
+  generation: {
+    tabs: {
+      create: 'Chạy Một lần',
+      batch: 'Chạy Theo Lô',
+      saved: 'Đã Lưu',
+    },
+    savedNoData: {
+      title: 'Bạn chưa lưu kết quả nào!',
+      description: 'Bắt đầu tạo nội dung và tìm kết quả đã lưu của bạn ở đây.',
+      startCreateContent: 'Bắt đầu tạo nội dung',
+    },
+    title: 'Hoàn Thiện AI',
+    queryTitle: 'Nội dung truy vấn',
+    completionResult: 'Kết quả hoàn thiện',
+    queryPlaceholder: 'Viết nội dung truy vấn của bạn...',
+    run: 'Thực thi',
+    copy: 'Sao chép',
+    resultTitle: 'Hoàn Thiện AI',
+    noData: 'AI sẽ đưa ra điều bạn muốn ở đây.',
+    csvUploadTitle: 'Kéo và thả tệp CSV của bạn vào đây, hoặc ',
+    browse: 'duyệt',
+    csvStructureTitle: 'Tệp CSV phải tuân thủ cấu trúc sau:',
+    downloadTemplate: 'Tải xuống mẫu tại đây',
+    field: 'Trường',
+    batchFailed: {
+      info: '{{num}} thực thi thất bại',
+      retry: 'Thử lại',
+      outputPlaceholder: 'Không có nội dung đầu ra',
+    },
+    errorMsg: {
+      empty: 'Vui lòng nhập nội dung vào tệp đã tải lên.',
+      fileStructNotMatch: 'Tệp CSV tải lên không khớp cấu trúc.',
+      emptyLine: 'Hàng {{rowIndex}} trống',
+      invalidLine: 'Hàng {{rowIndex}}: {{varName}} không thể để trống',
+      moreThanMaxLengthLine: 'Hàng {{rowIndex}}: {{varName}} không thể chứa nhiều hơn {{maxLength}} ký tự',
+      atLeastOne: 'Vui lòng nhập ít nhất một hàng vào tệp đã tải lên.',
+    },
+  },
+}
+
+export default translation
diff --git a/web/i18n/vi-VN/tools.ts b/web/i18n/vi-VN/tools.ts
new file mode 100644
index 0000000000..20266953ee
--- /dev/null
+++ b/web/i18n/vi-VN/tools.ts
@@ -0,0 +1,115 @@
+const translation = {
+  title: 'Công cụ',
+  createCustomTool: 'Tạo Công cụ Tùy chỉnh',
+  type: {
+    all: 'Tất cả',
+    builtIn: 'Tích hợp sẵn',
+    custom: 'Tùy chỉnh',
+  },
+  contribute: {
+    line1: 'Tôi quan tâm đến việc ',
+    line2: 'đóng góp công cụ cho Dify.',
+    viewGuide: 'Xem hướng dẫn',
+  },
+  author: 'Bởi',
+  auth: {
+    unauthorized: 'Chưa xác thực',
+    authorized: 'Đã xác thực',
+    setup: 'Thiết lập xác thực để sử dụng',
+    setupModalTitle: 'Thiết lập Xác thực',
+    setupModalTitleDescription: 'Sau khi cấu hình thông tin đăng nhập, tất cả các thành viên trong không gian làm việc có thể sử dụng công cụ này khi triển khai ứng dụng.',
+  },
+  includeToolNum: 'Bao gồm {{num}} công cụ',
+  addTool: 'Thêm Công cụ',
+  createTool: {
+    title: 'Tạo Công cụ Tùy chỉnh',
+    editAction: 'Cấu hình',
+    editTitle: 'Chỉnh sửa Công cụ Tùy chỉnh',
+    name: 'Tên',
+    toolNamePlaceHolder: 'Nhập tên công cụ',
+    schema: 'Schema',
+    schemaPlaceHolder: 'Nhập schema OpenAPI của bạn vào đây',
+    viewSchemaSpec: 'Xem Chi tiết OpenAPI-Swagger',
+    importFromUrl: 'Nhập từ URL',
+    importFromUrlPlaceHolder: 'https://...',
+    urlError: 'Vui lòng nhập URL hợp lệ',
+    examples: 'Ví dụ',
+    exampleOptions: {
+      json: 'Thời tiết (JSON)',
+      yaml: 'Pet Store (YAML)',
+      blankTemplate: 'Mẫu Trống',
+    },
+    availableTools: {
+      title: 'Công cụ Hiện có',
+      name: 'Tên',
+      description: 'Mô tả',
+      method: 'Phương thức',
+      path: 'Đường dẫn',
+      action: 'Hành động',
+      test: 'Kiểm tra',
+    },
+    authMethod: {
+      title: 'Phương thức Xác thực',
+      type: 'Loại xác thực',
+      keyTooltip: 'Khóa Tiêu đề HTTP, Bạn có thể để trống nếu không biết là gì hoặc đặt nó thành một giá trị tùy chỉnh',
+      types: {
+        none: 'Không',
+        api_key: 'Khóa API',
+        apiKeyPlaceholder: 'Tên tiêu đề HTTP cho Khóa API',
+        apiValuePlaceholder: 'Nhập Khóa API',
+      },
+      key: 'Khóa',
+      value: 'Giá trị',
+    },
+    authHeaderPrefix: {
+      title: 'Loại Xác thực',
+      types: {
+        basic: 'Cơ bản',
+        bearer: 'Bearer',
+        custom: 'Tùy chỉnh',
+      },
+    },
+    privacyPolicy: 'Chính sách bảo mật',
+    privacyPolicyPlaceholder: 'Vui lòng nhập chính sách bảo mật',
+  },
+  test: {
+    title: 'Kiểm tra',
+    parametersValue: 'Tham số & Giá trị',
+    parameters: 'Tham số',
+    value: 'Giá trị',
+    testResult: 'Kết quả Kiểm tra',
+    testResultPlaceholder: 'Kết quả kiểm tra sẽ hiển thị ở đây',
+  },
+  thought: {
+    using: 'Sử dụng',
+    used: 'Đã sử dụng',
+    requestTitle: 'Yêu cầu đến',
+    responseTitle: 'Phản hồi từ',
+  },
+  setBuiltInTools: {
+    info: 'Thông tin',
+    setting: 'Cài đặt',
+    toolDescription: 'Mô tả Công cụ',
+    parameters: 'Tham số',
+    string: 'chuỗi',
+    number: 'số',
+    required: 'Bắt buộc',
+    infoAndSetting: 'Thông tin & Cài đặt',
+  },
+  noCustomTool: {
+    title: 'Không có công cụ tùy chỉnh!',
+    content: 'Thêm và quản lý các công cụ tùy chỉnh của bạn ở đây để xây dựng ứng dụng AI.',
+    createTool: 'Tạo Công cụ',
+  },
+  noSearchRes: {
+    title: 'Xin lỗi, không có kết quả!',
+    content: 'Chúng tôi không thể tìm thấy bất kỳ công cụ nào phù hợp với tìm kiếm của bạn.',
+    reset: 'Thiết lập lại Tìm kiếm',
+  },
+  builtInPromptTitle: 'Lời nhắc',
+  toolRemoved: 'Công cụ đã được loại bỏ',
+  notAuthorized: 'Công cụ chưa được xác thực',
+  howToGet: 'Cách nhận',
+}
+
+export default translation

From 849dc0560b7058456a8e0f4a651880832d538474 Mon Sep 17 00:00:00 2001
From: Laurent Magnien
Date: Wed, 13 Mar 2024 11:20:55 +0100
Subject: [PATCH 025/200] feat: add French fr-FR (#2810)

Co-authored-by: Laurent Magnien
---
 web/i18n/fr-FR/app-annotation.ts      |  87 +++++
 web/i18n/fr-FR/app-api.ts             |  82 +++++
 web/i18n/fr-FR/app-debug.ts           | 409 +++++++++++++++++++++
 web/i18n/fr-FR/app-log.ts             |  69 ++++
 web/i18n/fr-FR/app-overview.ts        | 139 +++++++
 web/i18n/fr-FR/app.ts                 |  54 +++
 web/i18n/fr-FR/billing.ts             | 115 ++++++
 web/i18n/fr-FR/common.ts              | 505 ++++++++++++++++++++++++++
 web/i18n/fr-FR/custom.ts              |  30 ++
 web/i18n/fr-FR/dataset-creation.ts    | 130 +++++++
 web/i18n/fr-FR/dataset-documents.ts   | 349 ++++++++++++++++++
 web/i18n/fr-FR/dataset-hit-testing.ts |  28 ++
 web/i18n/fr-FR/dataset-settings.ts    |  33 ++
 web/i18n/fr-FR/dataset.ts             |  47 +++
 web/i18n/fr-FR/explore.ts             |  41 +++
 web/i18n/fr-FR/layout.ts              |   4 +
 web/i18n/fr-FR/login.ts               |  59 +++
 web/i18n/fr-FR/register.ts            |   4 +
 web/i18n/fr-FR/share-app.ts           |  74 ++++
 web/i18n/fr-FR/tools.ts               | 115 ++++++
 web/i18n/language.ts                  |   2 +-
 21 files changed, 2375 insertions(+), 1 deletion(-)
 create mode 100644 web/i18n/fr-FR/app-annotation.ts
 create mode 100644 web/i18n/fr-FR/app-api.ts
 create mode 100644 web/i18n/fr-FR/app-debug.ts
 create mode 100644 web/i18n/fr-FR/app-log.ts
 create mode 100644 web/i18n/fr-FR/app-overview.ts
 create mode 100644 web/i18n/fr-FR/app.ts
 create mode 100644 web/i18n/fr-FR/billing.ts
 create mode 100644 web/i18n/fr-FR/common.ts
 create mode 100644 web/i18n/fr-FR/custom.ts
 create mode 100644 web/i18n/fr-FR/dataset-creation.ts
 create mode 100644 web/i18n/fr-FR/dataset-documents.ts
 create mode 100644 web/i18n/fr-FR/dataset-hit-testing.ts
 create mode 100644 web/i18n/fr-FR/dataset-settings.ts
 create mode 100644 web/i18n/fr-FR/dataset.ts
 create mode 100644 web/i18n/fr-FR/explore.ts
 create mode 100644 web/i18n/fr-FR/layout.ts
 create mode 100644 web/i18n/fr-FR/login.ts
 create mode 100644 web/i18n/fr-FR/register.ts
 create mode 100644 web/i18n/fr-FR/share-app.ts
 create mode 100644 web/i18n/fr-FR/tools.ts

diff --git a/web/i18n/fr-FR/app-annotation.ts b/web/i18n/fr-FR/app-annotation.ts
new file mode 100644
index 0000000000..3926fe5e26
--- /dev/null
+++ b/web/i18n/fr-FR/app-annotation.ts
@@ -0,0 +1,87 @@
+const translation = {
+  title: 'Annotations',
+  name: 'Réponse à l\'Annotation',
+  editBy: 'Réponse éditée par {{author}}',
+  noData: {
+    title: 'Aucune annotation',
description: 'Vous pouvez modifier les annotations en déboguant l\'application, ou importer des annotations en masse ici pour une réponse de haute qualité.', + }, + table: { + header: { + question: 'question', + answer: 'réponse', + createdAt: 'créé à', + hits: 'clics', + actions: 'actions', + addAnnotation: 'Ajouter une Annotation', + bulkImport: 'Importation en Vrac', + bulkExport: 'Exportation en Vrac', + clearAll: 'Effacer toutes les annotations', + }, + }, + editModal: { + title: 'Modifier la réponse à l\'annotation', + queryName: 'Requête de l\'utilisateur', + answerName: 'Bot conteur', + yourAnswer: 'Votre réponse', + answerPlaceholder: 'Tapez votre réponse ici', + yourQuery: 'Votre requête', + queryPlaceholder: 'Tapez votre requête ici', + removeThisCache: 'Supprimez cette Annotation', + createdAt: 'Créé à', + }, + addModal: { + title: 'Ajouter une réponse d\'annotation', + queryName: 'Question', + answerName: 'Réponse', + answerPlaceholder: 'Tapez la réponse ici', + queryPlaceholder: 'Tapez la requête ici', + createNext: 'Ajoutez une autre réponse annotée', + }, + batchModal: { + title: 'Importation en masse', + csvUploadTitle: 'Glissez et déposez votre fichier CSV ici, ou', + browse: 'parcourir', + tip: 'Le fichier CSV doit se conformer à la structure suivante :', + question: 'question', + answer: 'réponse', + contentTitle: 'contenu du bloc', + content: 'contenu', + template: 'Téléchargez le modèle ici', + cancel: 'Annuler', + run: 'Exécuter le lot', + runError: 'L\'exécution du lot a échoué', + processing: 'Dans le traitement par lots', + completed: 'Importation terminée', + error: 'Erreur d\'Importation', + ok: 'D\'accord', + }, + errorMessage: { + answerRequired: 'Une réponse est requise', + queryRequired: 'La question est requise', + }, + viewModal: { + annotatedResponse: 'Réponse à l\'annotation', + hitHistory: 'Historique des coups', + hit: 'Clic', + hits: 'Clics', + noHitHistory: 'Aucun historique de recherche', + }, + hitHistoryTable: { + query: 'Requête', + match: 'Correspondance', + response: 'Réponse', + source: 'Source', + score: 'Score', + time: 'Temps', + }, + initSetup: { + title: 'Réponse d\'Annotation Configuration Initiale', + configTitle: 'Configuration de la Réponse d\'Annotation', + confirmBtn: 'Enregistrer & Activer', + configConfirmBtn: 'Enregistrer', + }, + embeddingModelSwitchTip: 'Modèle de vectorisation de texte d\'annotation, changer de modèles entraînera une ré-intégration, ce qui entraînera des coûts supplémentaires.', +} + +export default translation diff --git a/web/i18n/fr-FR/app-api.ts b/web/i18n/fr-FR/app-api.ts new file mode 100644 index 0000000000..c214e0a9c9 --- /dev/null +++ b/web/i18n/fr-FR/app-api.ts @@ -0,0 +1,82 @@ +const translation = { + apiServer: 'Serveur API', + apiKey: 'Clé API', + status: 'Statut', + disabled: 'Désactivé', + ok: 'En Service', + copy: 'Copier', + copied: 'Copié', + play: 'Jouer', + pause: 'Pause', + playing: 'Jouant', + merMaind: { + rerender: 'Refaire Rerendu', + }, + never: 'Jamais', + apiKeyModal: { + apiSecretKey: 'Clé secrète de l\'API', + apiSecretKeyTips: 'Pour prévenir l\'abus de l\'API, protégez votre clé API. Évitez de l\'utiliser comme du texte brut dans le code front-end. 
:)', + createNewSecretKey: 'Créer une nouvelle clé secrète', + secretKey: 'Clé Secrète', + created: 'CRÉÉ', + lastUsed: 'DERNIÈRE UTILISATION', + generateTips: 'Gardez cette clé dans un endroit sûr et accessible.', + }, + actionMsg: { + deleteConfirmTitle: 'Supprimer cette clé secrète ?', + deleteConfirmTips: 'Cette action ne peut pas être annulée.', + ok: 'D\'accord', + }, + completionMode: { + title: 'API de l\'application Completion', + info: 'Pour une génération de texte de haute qualité, telle que des articles, des résumés et des traductions, utilisez l\'API completion-messages avec l\'entrée de l\'utilisateur. La génération de texte repose sur les paramètres du modèle et les modèles de prompt définis dans', + createCompletionApi: 'Créer un Message de Fin', + createCompletionApiTip: 'Créez un message de fin pour soutenir le mode question-réponse.', + inputsTips: '(Facultatif) Fournissez des champs de saisie utilisateur sous forme de paires clé-valeur, correspondant aux variables dans Prompt Eng. La clé est le nom de la variable, la valeur est la valeur du paramètre. Si le type de champ est Sélection, la valeur sou', + queryTips: 'Contenu du texte saisi par l\'utilisateur.', + blocking: 'Type de blocage, en attente de l\'exécution pour terminer et renvoyer les résultats. (Les demandes peuvent être interrompues si le processus est long)', + streaming: 'retours en continu. Mise en œuvre de retours en continu basée sur SSE (Server-Sent Events).', + messageFeedbackApi: 'Retour de message (j\'aime)', + messageFeedbackApiTip: 'Évaluez les messages reçus au nom des utilisateurs finaux avec des likes ou des dislikes. Ces données sont visibles sur la page Logs & Annotations et sont utilisées pour le réglage fin des modèles futurs.', + messageIDTip: 'ID de message', + ratingTip: 'aimer ou ne pas aimer, null est annuler', + parametersApi: 'Obtenir des informations sur les paramètres de l\'application', + parametersApiTip: 'Récupérer les paramètres d\'entrée configurés, y compris les noms de variables, les noms de champs, les types et les valeurs par défaut. Généralement utilisé pour afficher ces champs dans un formulaire ou pour remplir les valeurs par défaut après le charg', + }, + chatMode: { + title: 'API de l\'application de chat', + info: 'Pour des applications conversationnelles polyvalentes utilisant un format Q&R, appelez l\'API de chat-messages pour initier le dialogue. Maintenez les conversations en cours en passant l\'ID de conversation retourné. Les paramètres de réponse et les modèles dépendent des paramètres de', + createChatApi: 'Créer un message de chat', + createChatApiTip: 'Créez un nouveau message de conversation ou continuez un dialogue existant.', + inputsTips: '(Facultatif) Fournir des champs de saisie utilisateur sous forme de paires clé-valeur, correspondant aux variables dans Prompt Eng. La clé est le nom de la variable, la valeur est la valeur du paramètre. Si le type de champ est Sélection, la valeur soumise', + queryTips: 'Contenu de la question/saisie de l\'utilisateur', + blocking: 'Type de blocage, en attente de l\'exécution pour terminer et renvoyer les résultats. (Les demandes peuvent être interrompues si le processus est long)', + streaming: 'retours en continu. 
Mise en œuvre de retours en continu basée sur SSE (Server-Sent Events).',
+    conversationIdTip: '(Facultatif) ID de conversation : laissez ce champ vide pour une première conversation ; transmettez le conversation_id du contexte pour poursuivre le dialogue.',
+    messageFeedbackApi: 'Message de retour d\'information de l\'utilisateur du terminal, comme',
+    messageFeedbackApiTip: 'Évaluez les messages reçus au nom des utilisateurs finaux avec des likes ou des dislikes. Ces données sont visibles sur la page Logs & Annotations et sont utilisées pour l\'ajustement futur du modèle.',
+    messageIDTip: 'ID de message',
+    ratingTip: 'aimer ou ne pas aimer, null est annuler',
+    chatMsgHistoryApi: 'Obtenez le message de l\'historique de chat',
+    chatMsgHistoryApiTip: 'La première page renvoie la dernière `limit` bar, qui est en ordre inverse.',
+    chatMsgHistoryConversationIdTip: 'ID de conversation',
+    chatMsgHistoryFirstId: 'ID du premier enregistrement de chat sur la page actuelle. La valeur par défaut est aucune.',
+    chatMsgHistoryLimit: 'Combien de chats sont renvoyés en une seule demande',
+    conversationsListApi: 'Obtenir la liste des conversations',
+    conversationsListApiTip: 'Obtient la liste des sessions de l\'utilisateur actuel. Par défaut, les 20 dernières sessions sont renvoyées.',
+    conversationsListFirstIdTip: 'L\'ID du dernier enregistrement sur la page actuelle, par défaut aucun.',
+    conversationsListLimitTip: 'Combien de chats sont renvoyés dans une seule requête',
+    conversationRenamingApi: 'Renommage de la conversation',
+    conversationRenamingApiTip: 'Renommez les conversations ; le nom est affiché dans les interfaces client multi-session.',
+    conversationRenamingNameTip: 'Nouveau nom',
+    parametersApi: 'Obtenir des informations sur les paramètres de l\'application',
+    parametersApiTip: 'Récupérer les paramètres d\'entrée configurés, y compris les noms de variables, les noms de champs, les types et les valeurs par défaut. 
Typiquement utilisé pour afficher ces champs dans un formulaire ou pour remplir les valeurs par défaut après le chargement du', + }, + develop: { + requestBody: 'Corps de la Requête', + pathParams: 'Params de chemin', + query: 'Requête', + }, +} + +export default translation diff --git a/web/i18n/fr-FR/app-debug.ts b/web/i18n/fr-FR/app-debug.ts new file mode 100644 index 0000000000..adda4c9475 --- /dev/null +++ b/web/i18n/fr-FR/app-debug.ts @@ -0,0 +1,409 @@ +const translation = { + pageTitle: { + line1: 'INVITATION', + line2: 'Ingénierie', + }, + orchestrate: 'Orchestrer', + promptMode: { + simple: 'Passez en Mode Expert pour modifier l\'intégralité du PROMPT', + advanced: 'Mode Expert', + switchBack: 'Revenir en arrière', + advancedWarning: { + title: 'Vous êtes passé en Mode Expert, et une fois que vous modifiez le PROMPT, vous NE POUVEZ PAS revenir au mode basique.', + description: 'En mode Expert, vous pouvez modifier l\'intégralité du PROMPT.', + learnMore: 'En savoir plus', + ok: 'D\'accord', + }, + operation: { + addMessage: 'Ajouter un message', + }, + contextMissing: 'Le composant de contexte est manquant, l\'efficacité de la suggestion peut ne pas être bonne.', + }, + operation: { + applyConfig: 'Publier', + resetConfig: 'Réinitialiser', + debugConfig: 'Déboguer', + addFeature: 'Ajouter une fonctionnalité', + automatic: 'Automatique', + stopResponding: 'Arrêtez de répondre', + agree: 'comme', + disagree: 'déteste', + cancelAgree: 'Annuler comme', + cancelDisagree: 'Annuler le dislike', + userAction: 'Utilisateur', + }, + notSetAPIKey: { + title: 'La clé du fournisseur LLM n\'a pas été définie', + trailFinished: 'Parcours terminé', + description: 'La clé du fournisseur LLM n\'a pas été définie, et elle doit être définie avant le débogage.', + settingBtn: 'Aller aux paramètres', + }, + trailUseGPT4Info: { + title: 'Ne prend pas en charge gpt-4 pour le moment', + description: 'Utilisez gpt-4, veuillez définir la clé API.', + }, + feature: { + groupChat: { + title: 'Amélioration de chat', + description: 'Ajouter des paramètres de pré-conversation pour les applications peut améliorer l\'expérience utilisateur.', + }, + groupExperience: { + title: 'Amélioration de l\'expérience', + }, + conversationOpener: { + title: 'Convertisseurs de conversation', + description: 'Dans une application de chat, la première phrase que l\'IA prononce activement à l\'utilisateur est généralement utilisée comme message de bienvenue.', + }, + suggestedQuestionsAfterAnswer: { + title: 'Suivi', + description: 'La configuration de la suggestion des prochaines questions peut offrir aux utilisateurs une meilleure discussion.', + resDes: '3 suggestions pour la prochaine question de l\'utilisateur.', + tryToAsk: 'Essayez de demander', + }, + moreLikeThis: { + title: 'Plus comme ça', + description: 'Générez plusieurs textes à la fois, puis modifiez et continuez à générer', + generateNumTip: 'Nombre de chaque temps généré', + tip: 'L\'utilisation de cette fonctionnalité entraînera un surcoût de tokens supplémentaires', + }, + speechToText: { + title: 'Discours en Texte', + description: 'Une fois activé, vous pouvez utiliser l\'entrée vocale.', + resDes: 'La saisie vocale est activée', + }, + textToSpeech: { + title: 'Texte à la parole', + description: 'Une fois activé, le texte peut être converti en parole.', + resDes: 'La Texte à Audio est activée', + }, + citation: { + title: 'Citations et Attributions', + description: 'Une fois activé, affichez le document source et la section attribuée du contenu généré.', + 
resDes: 'Les citations et attributions sont activées', + }, + annotation: { + title: 'Réponse d\'Annotation', + description: 'Vous pouvez manuellement ajouter une réponse de haute qualité au cache pour une correspondance prioritaire avec des questions d\'utilisateur similaires.', + resDes: 'La réponse d\'annotation est activée', + scoreThreshold: { + title: 'Seuil de Score', + description: 'Utilisé pour définir le seuil de similarité pour la réponse d\'annotation.', + easyMatch: 'Correspondance Facile', + accurateMatch: 'Correspondance précise', + }, + matchVariable: { + title: 'Correspondance de Variable', + choosePlaceholder: 'Choisissez la variable correspondante', + }, + cacheManagement: 'Annotations', + cached: 'Annoté', + remove: 'Supprimer', + removeConfirm: 'Supprimer cette annotation ?', + add: 'Ajouter une annotation', + edit: 'Modifier l\'annotation', + }, + dataSet: { + title: 'Contexte', + noData: 'Vous pouvez importer des Connaissances comme contexte', + words: 'Mots', + textBlocks: 'Blocs de texte', + selectTitle: 'Sélectionnez la connaissance de référence', + selected: 'Connaissance sélectionnée', + noDataSet: 'Aucune connaissance trouvée', + toCreate: 'Aller à créer', + notSupportSelectMulti: 'Actuellement, ne prend en charge qu\'une seule Connaissance', + queryVariable: { + title: 'Variable de requête', + tip: 'Cette variable sera utilisée comme entrée de requête pour la récupération du contexte, obtenant des informations contextuelles liées à l\'entrée de cette variable.', + choosePlaceholder: 'Choisissez la variable de requête', + noVar: 'Aucune variable', + noVarTip: 'veuillez créer une variable sous la section Variables', + unableToQueryDataSet: 'Impossible de questionner la Connaissance', + unableToQueryDataSetTip: 'Impossible d\'interroger la Connaissance avec succès, veuillez choisir une variable de requête de contexte dans la section contexte.', + ok: 'D\'accord', + contextVarNotEmpty: 'la variable de requête de contexte ne peut pas être vide', + deleteContextVarTitle: 'Supprimer la variable "{{varName}}" ?', + deleteContextVarTip: 'Cette variable a été définie comme une variable de requête de contexte, et sa suppression affectera l\'utilisation normale de la Connaissance. 
Si vous devez toujours la supprimer, veuillez la re-sélectionner dans la section contexte.',
+      },
+    },
+    tools: {
+      title: 'Outils',
+      tips: 'Les outils fournissent une méthode d\'appel API standard, prenant en compte les entrées de l\'utilisateur ou des variables comme paramètres de requête pour interroger des données externes en tant que contexte.',
+      toolsInUse: '{{count}} outils en usage',
+      modal: {
+        title: 'Outil',
+        toolType: {
+          title: 'Type d\'outil',
+          placeholder: 'Veuillez sélectionner le type d\'outil',
+        },
+        name: {
+          title: 'Nom',
+          placeholder: 'Veuillez entrer le nom',
+        },
+        variableName: {
+          title: 'Nom de la Variable',
+          placeholder: 'Veuillez entrer le nom de la variable',
+        },
+      },
+    },
+    conversationHistory: {
+      title: 'Historique de Conversation',
+      description: 'Définir les noms de préfixe pour les rôles de conversation',
+      tip: 'L\'historique de conversation n\'est pas activé, veuillez ajouter dans le prompt ci-dessus.',
+      learnMore: 'En savoir plus',
+      editModal: {
+        title: 'Modifier les noms de rôles de conversation',
+        userPrefix: 'Préfixe utilisateur',
+        assistantPrefix: 'Préfixe de l\'assistant',
+      },
+    },
+    toolbox: {
+      title: 'BOÎTE À OUTILS',
+    },
+    moderation: {
+      title: 'Modération de contenu',
+      description: 'Sécurisez la sortie du modèle en utilisant l\'API de modération ou en conservant une liste de mots sensibles.',
+      allEnabled: 'Contenu Entrant/Sortant Activé',
+      inputEnabled: 'Contenu ENTRANT activé',
+      outputEnabled: 'Contenu de SORTIE activé',
+      modal: {
+        title: 'Paramètres de modération de contenu',
+        provider: {
+          title: 'Fournisseur',
+          openai: 'Modération OpenAI',
+          openaiTip: {
+            prefix: 'La modération d\'OpenAI nécessite une clé API OpenAI configurée dans le',
+            suffix: '.',
+          },
+          keywords: 'Mots-clés',
+        },
+        keywords: {
+          tip: 'Un par ligne, séparés par des sauts de ligne. Jusqu\'à 100 caractères par ligne.',
+          placeholder: 'Un par ligne, séparé par des sauts de ligne',
+          line: 'Ligne',
+        },
+        content: {
+          input: 'Modérer le contenu INPUT',
+          output: 'Modérer le contenu de SORTIE',
+          preset: 'Réponses préétablies',
+          placeholder: 'Contenu des réponses prédéfinies ici',
+          condition: 'Contenu MODÉRÉ pour INPUT et OUTPUT activé au moins une fois',
+          fromApi: 'Les réponses prédéfinies sont renvoyées par l\'API',
+          errorMessage: 'Les réponses prédéfinies ne peuvent pas être vides',
+          supportMarkdown: 'Prise en charge de Markdown',
+        },
+        openaiNotConfig: {
+          before: 'La modération d\'OpenAI nécessite une clé API OpenAI configurée dans le',
+          after: '',
+        },
+      },
+    },
+  },
+  automatic: {
+    title: 'Orchestration automatique d\'application',
+    description: 'Décrivez votre scénario, Dify orchestrera une application pour vous.',
+    intendedAudience: 'Qui est le public cible ?',
+    intendedAudiencePlaceHolder: 'par exemple. 
Étudiant',
+    solveProblem: 'Quels problèmes espèrent-ils que l\'IA peut résoudre pour eux ?',
+    solveProblemPlaceHolder: 'par exemple, Évaluation des performances académiques',
+    generate: 'Générer',
+    audiencesRequired: 'Audiences requises',
+    problemRequired: 'Problème requis',
+    resTitle: 'Nous avons orchestré l\'application suivante pour vous.',
+    apply: 'Appliquez cette orchestration',
+    noData: 'Décrivez votre cas d\'utilisation sur la gauche, l\'aperçu de l\'orchestration s\'affichera ici.',
+    loading: 'Orchestration de l\'application pour vous...',
+    overwriteTitle: 'Remplacer la configuration existante ?',
+    overwriteMessage: 'L\'application de cette orchestration remplacera la configuration existante.',
+  },
+  resetConfig: {
+    title: 'Confirmer la réinitialisation ?',
+    message:
+      'Réinitialiser supprime les modifications, en restaurant la dernière configuration publiée.',
+  },
+  errorMessage: {
+    nameOfKeyRequired: 'Le nom de la clé : {{key}} est requis',
+    valueOfVarRequired: 'La valeur de {{key}} ne peut pas être vide',
+    queryRequired: 'Le texte de la requête est requis.',
+    waitForResponse:
+      'Veuillez attendre que la réponse au message précédent soit terminée.',
+    waitForBatchResponse:
+      'Veuillez attendre que la réponse à la tâche en lot soit terminée.',
+    notSelectModel: 'Veuillez choisir un modèle',
+    waitForImgUpload: 'Veuillez attendre que l\'image soit téléchargée',
+  },
+  chatSubTitle: 'Instructions',
+  completionSubTitle: 'Indicatif de Prompt',
+  promptTip:
+    'Les prompts guident les réponses de l\'IA avec des instructions et des contraintes. Insérez des variables comme {{input}}. Ce prompt ne sera pas visible pour les utilisateurs.',
+  formattingChangedTitle: 'Formatage modifié',
+  formattingChangedText:
+    'La modification du formatage réinitialisera la zone de débogage, êtes-vous sûr ?',
+  variableTitle: 'Variables',
+  variableTip:
+    'Les utilisateurs remplissent des variables dans un formulaire, remplaçant automatiquement les variables dans le prompt.',
+  notSetVar: 'Les variables permettent aux utilisateurs d\'introduire des mots de prompt ou des remarques d\'ouverture lors du remplissage des formulaires. Vous pouvez essayer de saisir "{{input}}" dans les mots de prompt.',
+  autoAddVar: 'Des variables indéfinies référencées dans le pre-prompt, voulez-vous les ajouter dans le formulaire d\'entrée de l\'utilisateur ?',
+  variableTable: {
+    key: 'Clé Variable',
+    name: 'Nom du champ d\'entrée de l\'utilisateur',
+    optional: 'Facultatif',
+    type: 'Type d\'Entrée',
+    action: 'Actions',
+    typeString: 'Chaîne',
+    typeSelect: 'Sélectionner',
+  },
+  varKeyError: {
+    canNoBeEmpty: 'La clé variable ne peut pas être vide',
+    tooLong: 'La clé de variable : {{key}} est trop longue. Elle ne peut pas dépasser 30 caractères',
+    notValid: 'La clé de variable : {{key}} est invalide. Elle ne peut contenir que des lettres, des chiffres et des tirets bas',
+    notStartWithNumber: 'La clé de variable : {{key}} ne peut pas commencer par un chiffre',
+    keyAlreadyExists: 'La clé de variable : {{key}} existe déjà',
+  },
+  otherError: {
+    promptNoBeEmpty: 'Le prompt ne peut pas être vide',
+    historyNoBeEmpty: 'L\'historique de la conversation doit être défini dans le prompt',
+    queryNoBeEmpty: 'La requête doit être définie dans le prompt',
+  },
+  variableConig: {
+    modalTitle: 'Paramètres du champ',
+    description: 'Paramètre pour la variable {{varName}}',
+    fieldType: 'Type de champ',
+    string: 'Texte Court',
+    paragraph: 'Paragraphe',
+    select: 'Sélectionner',
+    notSet: 'Non défini, essayez de taper {{input}} dans l\'invite de préfixe',
+    stringTitle: 'Options de la boîte de texte du formulaire',
+    maxLength: 'Longueur maximale',
+    options: 'Options',
+    addOption: 'Ajouter une option',
+    apiBasedVar: 'Variable basée sur l\'API',
+  },
+  vision: {
+    name: 'Vision',
+    description: 'Activer Vision permettra au modèle de prendre des images et de répondre à des questions à leur sujet.',
+    settings: 'Paramètres',
+    visionSettings: {
+      title: 'Paramètres de Vision',
+      resolution: 'Résolution',
+      resolutionTooltip: `la basse résolution permet au modèle de recevoir une version basse résolution de l'image en 512 x 512 et de représenter l'image avec un budget de 65 tokens. Cela permet à l'API de renvoyer des réponses plus rapides et de consommer moins de tokens d'entrée pour les cas d'utilisation ne nécessitant pas un haut niveau de détail.
+      \n
+      la haute résolution permet d'abord au modèle de voir l'image en basse résolution, puis crée des découpes détaillées des images d'entrée en carrés de 512px en fonction de la taille de l'image d'entrée. Chaque découpe détaillée utilise le double du budget de tokens, pour un total de 129 tokens.`,
+      high: 'Élevé',
+      low: 'Faible',
+      uploadMethod: 'Méthode de Téléchargement',
+      both: 'Les deux',
+      localUpload: 'Téléchargement Local',
+      url: 'URL',
+      uploadLimit: 'Limite de téléchargement',
+    },
+  },
+  voice: {
+    name: 'Voix',
+    defaultDisplay: 'Voix par défaut',
+    description: 'Paramètres de la voix de synthèse vocale',
+    settings: 'Paramètres',
+    voiceSettings: {
+      title: 'Paramètres de voix',
+      language: 'Langue',
+      resolutionTooltip: 'Support de la langue pour la voix de synthèse de texte.',
+      voice: 'Voix',
+    },
+  },
+  openingStatement: {
+    title: 'Ouverture de Conversation',
+    add: 'Ajouter',
+    writeOpner: 'Écrire l\'introduction',
+    placeholder: 'Rédigez votre message d\'ouverture ici, vous pouvez utiliser des variables, essayez de taper {{variable}}.',
+    openingQuestion: 'Questions d\'ouverture',
+    noDataPlaceHolder:
+      'Commencer la conversation avec l\'utilisateur peut aider l\'IA à établir une connexion plus proche avec eux dans les applications conversationnelles.',
+    varTip: 'Vous pouvez utiliser des variables, essayez de taper {{variable}}',
+    tooShort: 'Au moins 20 mots de l\'invite initiale sont requis pour générer des remarques d\'ouverture pour la conversation.',
+    notIncludeKey: 'Le prompt initial n\'inclut pas la variable : {{key}}. Veuillez l\'ajouter au prompt initial.',
+  },
+  modelConfig: {
+    model: 'Modèle',
+    setTone: 'Définir le ton des réponses',
+    title: 'Modèle et Paramètres',
+    modeType: {
+      chat: 'Discussion',
+      completion: 'Complet',
+    },
+  },
+  inputs: {
+    title: 'Déboguer et Aperçu',
+    noPrompt: 'Essayez d\'écrire une proposition dans l\'entrée pré-proposition',
+    userInputField: 'Champ de saisie utilisateur',
+    noVar: 'Remplissez la valeur de la variable, qui sera automatiquement remplacée dans le mot d\'invite chaque fois qu\'une nouvelle session est démarrée.',
+    chatVarTip:
+      'Remplissez la valeur de la variable, qui sera automatiquement remplacée dans le mot d\'invite chaque fois qu\'une nouvelle session est démarrée',
+    completionVarTip:
+      'Remplissez la valeur de la variable, qui sera automatiquement remplacée dans les mots de l\'invite chaque fois qu\'une question est soumise.',
+    previewTitle: 'Aperçu de la prompte',
+    queryTitle: 'Contenu de la requête',
+    queryPlaceholder: 'Veuillez entrer le texte de la demande.',
+    run: 'EXÉCUTER',
+  },
+  result: 'Texte de sortie',
+  datasetConfig: {
+    settingTitle: 'Paramètres de récupération',
+    retrieveOneWay: {
+      title: 'Récupération N-vers-1',
+      description: 'En fonction de l\'intention de l\'utilisateur et des descriptions de Connaissance, l\'Agent sélectionne de manière autonome la meilleure Connaissance pour interroger. Idéal pour les applications avec une Connaissance distincte et limitée.',
+    },
+    retrieveMultiWay: {
+      title: 'Récupération multi-chemins',
+      description: 'En fonction de l\'intention de l\'utilisateur, interroge toutes les connaissances, récupère le texte pertinent de plusieurs sources et sélectionne les meilleurs résultats correspondant à la requête de l\'utilisateur après réordonnancement. La configuration de l\'API du modèle de réordonnancement est requise',
+    },
+    rerankModelRequired: 'Un modèle de réorganisation est nécessaire',
+    params: 'Paramètres',
+    top_k: 'Top K',
+    top_kTip: 'Utilisé pour filtrer les morceaux qui sont les plus similaires aux questions de l\'utilisateur. 
Le système ajustera également dynamiquement la valeur de Top K, selon max_tokens du modèle sélectionné.',
+    score_threshold: 'Seuil de Score',
+    score_thresholdTip: 'Utilisé pour définir le seuil de similarité pour le filtrage des morceaux.',
+    retrieveChangeTip: 'La modification du mode d\'indexation et du mode de récupération peut affecter les applications associées à cette Connaissance.',
+  },
+  debugAsSingleModel: 'Déboguer comme Modèle Unique',
+  debugAsMultipleModel: 'Déboguer en tant que Modèles Multiples',
+  duplicateModel: 'Dupliquer',
+  publishAs: 'Publier comme',
+  assistantType: {
+    name: 'Type d\'Assistant',
+    chatAssistant: {
+      name: 'Assistant de Base',
+      description: 'Construisez un assistant basé sur le chat en utilisant un grand modèle de langage',
+    },
+    agentAssistant: {
+      name: 'Assistant Agent',
+      description: 'Construisez un Agent intelligent qui peut choisir de manière autonome des outils pour accomplir les tâches',
+    },
+  },
+  agent: {
+    agentMode: 'Mode Agent',
+    agentModeDes: 'Définissez le type de mode d\'inférence pour l\'agent',
+    agentModeType: {
+      ReACT: 'ReAct',
+      functionCall: 'Appel de fonction',
+    },
+    setting: {
+      name: 'Paramètres de l\'Agent',
+      description: 'Les paramètres de l\'Assistant Agent permettent de définir le mode de l\'agent et des fonctionnalités avancées comme les prompts intégrés, uniquement disponibles dans le type Agent.',
+      maximumIterations: {
+        name: 'Nombre Maximum d\'Itérations',
+        description: 'Limiter le nombre d\'itérations qu\'un assistant agent peut exécuter',
+      },
+    },
+    buildInPrompt: 'Prompt Intégré',
+    firstPrompt: 'Première Prompte',
+    nextIteration: 'Prochaine Itération',
+    promptPlaceholder: 'Écrivez votre prompt ici',
+    tools: {
+      name: 'Outils',
+      description: 'L\'utilisation d\'outils peut étendre les capacités des LLM, comme la recherche sur internet ou l\'exécution de calculs scientifiques.',
+      enabled: 'Activé',
+    },
+  },
+}
+
+export default translation
diff --git a/web/i18n/fr-FR/app-log.ts b/web/i18n/fr-FR/app-log.ts
new file mode 100644
index 0000000000..4ff1985869
--- /dev/null
+++ b/web/i18n/fr-FR/app-log.ts
@@ -0,0 +1,69 @@
+const translation = {
+  title: 'Journaux',
+  description: 'Les journaux enregistrent l\'état de fonctionnement de l\'application, y compris les entrées de l\'utilisateur et les réponses de l\'IA.',
+  dateTimeFormat: 'DD/MM/YYYY hh:mm A',
+  table: {
+    header: {
+      time: 'Temps',
+      endUser: 'Utilisateur final',
+      input: 'Entrée',
+      output: 'Sortie',
+      summary: 'Titre',
+      messageCount: 'Nombre de Messages',
+      userRate: 'Taux d\'utilisateur',
+      adminRate: 'Taux Op.',
+    },
+    pagination: {
+      previous: 'Précédent',
+      next: 'Suivant',
+    },
+    empty: {
+      noChat: 'Aucune conversation pour le moment',
+      noOutput: 'Aucune sortie',
+      element: {
+        title: 'Quelqu\'un est là ?',
+        content: 'Observez et annotez les interactions entre les utilisateurs finaux et les applications IA ici pour améliorer continuellement la précision de l\'IA. Vous pouvez essayer de partager ou de tester l\'application Web',
+      },
+    },
+  },
+  detail: {
+    time: 'Temps',
+    conversationId: 'ID de conversation',
+    promptTemplate: 'Modèle de Prompt',
+    promptTemplateBeforeChat: 'Modèle de Prompt Avant le Chat · En Tant que Message Système',
+    annotationTip: 'Améliorations Marquées par {{user}}',
+    timeConsuming: '',
+    second: 's',
+    tokenCost: 'Jeton dépensé',
+    loading: 'chargement',
+    operation: {
+      like: 'j\'aime',
+      dislike: 'je n\'aime pas',
+      addAnnotation: 'Ajouter une amélioration',
+      editAnnotation: 'Amélioration de l\'édition',
+      annotationPlaceholder: 'Entrez la réponse attendue que vous souhaitez que l\'IA donne, qui peut être utilisée pour l\'ajustement fin du modèle et l\'amélioration continue de la qualité de génération de texte à l\'avenir.',
+    },
+    variables: 'Variables',
+    uploadImages: 'Images Téléchargées',
+  },
+  filter: {
+    period: {
+      today: 'Aujourd\'hui',
+      last7days: 'Les 7 Derniers Jours',
+      last4weeks: 'Les 4 dernières semaines',
+      last3months: 'Les 3 derniers mois',
+      last12months: 'Les 12 derniers mois',
+      monthToDate: 'Mois à ce jour',
+      quarterToDate: 'Trimestre à ce jour',
+      yearToDate: 'Année à ce jour',
+      allTime: 'Tout le temps',
+    },
+    annotation: {
+      all: 'Tout',
+      annotated: 'Améliorations annotées ({{count}} éléments)',
+      not_annotated: 'Non Annoté',
+    },
+  },
+}
+
+export default translation
diff --git a/web/i18n/fr-FR/app-overview.ts b/web/i18n/fr-FR/app-overview.ts
new file mode 100644
index 0000000000..faad0e66ff
--- /dev/null
+++ b/web/i18n/fr-FR/app-overview.ts
@@ -0,0 +1,139 @@
+const translation = {
+  welcome: {
+    firstStepTip: 'Pour commencer,',
+    enterKeyTip: 'entrez votre clé API OpenAI ci-dessous',
+    getKeyTip: 'Obtenez votre clé API depuis le tableau de bord OpenAI',
+    placeholder: 'Votre clé API OpenAI (par exemple, sk-xxxx)',
+  },
+  apiKeyInfo: {
+    cloud: {
+      trial: {
+        title: 'Vous utilisez le quota d\'essai de {{providerName}}.',
+        description: 'Le quota d\'essai est fourni pour votre utilisation de test. Avant que les appels de quota d\'essai ne soient épuisés, veuillez configurer votre propre fournisseur de modèle ou acheter un quota supplémentaire.',
+      },
+      exhausted: {
+        title: 'Votre quota d\'essai a été utilisé, veuillez configurer votre APIKey.',
+        description: 'Votre quota d\'essai a été épuisé. 
Veuillez configurer votre propre fournisseur de modèle ou acheter un quota supplémentaire.',
+      },
+    },
+    selfHost: {
+      title: {
+        row1: 'Pour commencer,',
+        row2: 'configurez d\'abord votre fournisseur de modèle.',
+      },
+    },
+    callTimes: 'Heures d\'appel',
+    usedToken: 'Token utilisé',
+    setAPIBtn: 'Allez configurer le fournisseur de modèle',
+    tryCloud: 'Ou essayez la version cloud de Dify avec un quota gratuit',
+  },
+  overview: {
+    title: 'Aperçu',
+    appInfo: {
+      explanation: 'WebApp IA prête à l\'emploi',
+      accessibleAddress: 'URL publique',
+      preview: 'Aperçu',
+      regenerate: 'Régénérer',
+      preUseReminder: 'Veuillez activer WebApp avant de continuer.',
+      settings: {
+        entry: 'Paramètres',
+        title: 'Paramètres de l\'application Web',
+        webName: 'Nom de l\'application Web',
+        webDesc: 'Description de l\'application web',
+        webDescTip: 'Ce texte sera affiché du côté du client, fournissant des indications de base sur comment utiliser l\'application',
+        webDescPlaceholder: 'Entrez la description de la WebApp',
+        language: 'Langue',
+        more: {
+          entry: 'Montrer plus de paramètres',
+          copyright: 'Droit d\'auteur',
+          copyRightPlaceholder: 'Entrez le nom de l\'auteur ou de l\'organisation',
+          privacyPolicy: 'Politique de Confidentialité',
+          privacyPolicyPlaceholder: 'Entrez le lien de la politique de confidentialité',
+          privacyPolicyTip: 'Aide les visiteurs à comprendre les données que l\'application collecte, voir la Politique de Confidentialité de Dify.',
+        },
+      },
+      embedded: {
+        entry: 'Intégré',
+        title: 'Intégrer sur le site web',
+        explanation: 'Choisissez la manière d\'intégrer l\'application de chat à votre site web',
+        iframe: 'Pour ajouter l\'application de chat n\'importe où sur votre site web, ajoutez cette iframe à votre code html.',
+        scripts: 'Pour ajouter une application de chat en bas à droite de votre site web, ajoutez ce code à votre html.',
+        chromePlugin: 'Installez l\'extension Chrome Dify Chatbot',
+        copied: 'Copié',
+        copy: 'Copier',
+      },
+      qrcode: {
+        title: 'QR code à partager',
+        scan: 'Application de Partage de Scan',
+        download: 'Télécharger le Code QR',
+      },
+      customize: {
+        way: 'manière',
+        entry: 'Personnaliser',
+        title: 'Personnaliser la WebApp IA',
+        explanation: 'Vous pouvez personnaliser l\'interface utilisateur de l\'application Web pour répondre à vos besoins en termes de scénario et de style.',
+        way1: {
+          name: 'Forkez le code client, modifiez-le et déployez-le sur Vercel (recommandé)',
+          step1: 'Forkez le code client et modifiez-le',
+          step1Tip: 'Cliquez ici pour forker le code source dans votre compte GitHub et modifier le code',
+          step1Operation: 'Dify-WebClient',
+          step2: 'Déployer sur Vercel',
+          step2Tip: 'Cliquez ici pour importer le dépôt dans Vercel et déployer',
+          step2Operation: 'Importer le dépôt',
+          step3: 'Configurer les variables d\'environnement',
+          step3Tip: 'Ajoutez les variables d\'environnement suivantes dans Vercel',
+        },
+        way2: {
+          name: 'Écrivez du code côté client pour appeler l\'API et déployez-le sur un serveur',
+          operation: 'Documentation',
+        },
+      },
+    },
+    apiInfo: {
+      title: 'API du service Backend',
+      explanation: 'Facilement intégré dans votre application',
+      accessibleAddress: 'Point de terminaison du service API',
+      doc: 'Référence API',
+    },
+    status: {
+      running: 'En service',
+      disable: 'Désactiver',
+    },
+  },
+  analysis: {
+    title: 'Analyse',
+    ms: 'ms',
+    tokenPS: 'Jeton/s',
+    totalMessages: {
+      title: 'Messages Totaux',
+      explanation: 'Nombre quotidien d\'interactions IA ; ingénierie/débogage de prompt exclu.',
+    },
+    activeUsers: {
+      title: 'Utilisateurs Actifs',
+      explanation: 'Utilisateurs uniques participant à des Q&A avec l\'IA ; l\'ingénierie/débogage de prompt exclu.',
+    },
+    tokenUsage: {
+      title: 'Utilisation de Token',
+      explanation: 'Reflet de l\'utilisation quotidienne des jetons du modèle de langage pour l\'application, utile à des fins de contrôle des coûts.',
+      consumed: 'Consommé',
+    },
+    avgSessionInteractions: {
+      title: 'Interactions Moyennes par Session',
+      explanation: 'Comptage continu de la communication utilisateur-IA ; pour les applications basées sur la conversation.',
+    },
+    userSatisfactionRate: {
+      title: 'Taux de Satisfaction de l\'Utilisateur',
+      explanation: 'Le nombre de "j\'aime" par 1 000 messages. Cela indique la proportion de réponses dont les utilisateurs sont très satisfaits.',
+    },
+    avgResponseTime: {
+      title: 'Temps de réponse moyen',
+      explanation: 'Temps (ms) pour que l\'IA traite/réponde ; pour les applications basées sur le texte.',
+    },
+    tps: {
+      title: 'Vitesse de Sortie des Tokens',
+      explanation: 'Mesurez la performance du LLM. Comptez la vitesse de sortie des Tokens du LLM depuis le début de la demande jusqu\'à l\'achèvement de la sortie.',
+    },
+  },
+}
+
+export default translation
diff --git a/web/i18n/fr-FR/app.ts b/web/i18n/fr-FR/app.ts
new file mode 100644
index 0000000000..f8a6f721bc
--- /dev/null
+++ b/web/i18n/fr-FR/app.ts
@@ -0,0 +1,54 @@
+const translation = {
+  createApp: 'Créer une nouvelle application',
+  types: {
+    all: 'Tout',
+    assistant: 'Assistant',
+    completion: 'Complétion',
+  },
+  modes: {
+    completion: 'Générateur de Texte',
+    chat: 'Assistant de Base',
+  },
+  createFromConfigFile: 'Créer une application à partir du fichier de configuration',
+  deleteAppConfirmTitle: 'Supprimer cette application ?',
+  deleteAppConfirmContent:
+    'La suppression de l\'application est irréversible. Les utilisateurs ne pourront plus accéder à votre application, et toutes les configurations de prompt et les journaux seront définitivement supprimés.',
+  appDeleted: 'Application supprimée',
+  appDeleteFailed: 'Échec de la suppression de l\'application',
+  join: 'Rejoignez la communauté',
+  communityIntro:
+    'Discutez avec les membres de l\'équipe, les contributeurs et les développeurs sur différents canaux.',
+  roadmap: 'Voir notre feuille de route',
+  appNamePlaceholder: 'Veuillez entrer le nom de l\'application',
+  newApp: {
+    startToCreate: 'Commençons avec votre nouvelle application',
+    captionName: 'Icône et nom de l\'application',
+    captionAppType: 'Quel type d\'application voulez-vous créer ?',
+    previewDemo: 'Aperçu de la démo',
+    chatApp: 'Assistant',
+    chatAppIntro:
+      'Je veux construire une application basée sur le chat. 
Cette application utilise un format de questions-réponses, permettant plusieurs tours de conversation continue.', + agentAssistant: 'Nouvel Assistant Agent', + completeApp: 'Générateur de Texte', + completeAppIntro: + 'Je veux créer une application qui génère du texte de haute qualité basé sur des prompts, tels que la génération d\'articles, de résumés, de traductions, et plus encore.', + showTemplates: 'Je veux choisir à partir d\'un modèle', + hideTemplates: 'Revenir à la sélection de mode', + Create: 'Créer', + Cancel: 'Annuler', + nameNotEmpty: 'Le nom ne peut pas être vide', + appTemplateNotSelected: 'Veuillez sélectionner un modèle', + appTypeRequired: 'Veuillez sélectionner un type d\'application', + appCreated: 'Application créée', + appCreateFailed: 'Échec de la création de l\'application', + }, + editApp: { + startToEdit: 'Modifier l\'application', + }, + emoji: { + ok: 'D\'accord', + cancel: 'Annuler', + }, +} + +export default translation diff --git a/web/i18n/fr-FR/billing.ts b/web/i18n/fr-FR/billing.ts new file mode 100644 index 0000000000..09c8ca43a8 --- /dev/null +++ b/web/i18n/fr-FR/billing.ts @@ -0,0 +1,115 @@ +const translation = { + currentPlan: 'Plan Actuel', + upgradeBtn: { + plain: 'Mettre à jour le plan', + encourage: 'Mettre à niveau maintenant', + encourageShort: 'Mise à niveau', + }, + viewBilling: 'Gérer la facturation et les abonnements', + buyPermissionDeniedTip: 'Veuillez contacter votre administrateur d\'entreprise pour vous abonner', + plansCommon: { + title: 'Choisissez un plan qui vous convient', + yearlyTip: 'Obtenez 2 mois gratuitement en vous abonnant annuellement !', + mostPopular: 'Le Plus Populaire', + planRange: { + monthly: 'Mensuel', + yearly: 'Annuel', + }, + month: 'mois', + year: 'année', + save: 'Enregistrer', + free: 'Gratuit', + currentPlan: 'Plan Actuel', + contractSales: 'Contactez les ventes', + contractOwner: 'Contacter le chef d\'équipe', + startForFree: 'Commencez gratuitement', + getStartedWith: 'Commencez avec', + contactSales: 'Contacter les ventes', + talkToSales: 'Parlez aux Ventes', + modelProviders: 'Fournisseurs de Modèles', + teamMembers: 'Membres de l\'équipe', + buildApps: 'Construire des Applications', + vectorSpace: 'Espace Vectoriel', + vectorSpaceBillingTooltip: 'Chaque 1MB peut stocker environ 1,2 million de caractères de données vectorisées (estimé en utilisant les embeddings OpenAI, varie selon les modèles).', + vectorSpaceTooltip: 'L\'espace vectoriel est le système de mémoire à long terme nécessaire pour que les LLMs comprennent vos données.', + documentsUploadQuota: 'Quota de téléchargement de documents', + documentProcessingPriority: 'Priorité de Traitement de Document', + documentProcessingPriorityTip: 'Pour une priorité de traitement de documents plus élevée, veuillez mettre à niveau votre plan.', + documentProcessingPriorityUpgrade: 'Traitez plus de données avec une précision plus élevée à des vitesses plus rapides.', + priority: { + 'standard': 'Standard', + 'priority': 'Priorité', + 'top-priority': 'Priorité Maximale', + }, + logsHistory: 'Historique des logs', + customTools: 'Outils personnalisés', + unavailable: 'Indisponible', + days: 'jours', + unlimited: 'Illimité', + support: 'Assistance', + supportItems: { + communityForums: 'Forums communautaires', + emailSupport: 'Support par email', + priorityEmail: 'Support prioritaire par email et chat', + logoChange: 'Changement de logo', + SSOAuthentication: 'Authentification SSO', + personalizedSupport: 'Soutien personnalisé', + dedicatedAPISupport: 'Support 
dédié pour l\'API', + customIntegration: 'Intégration personnalisée et support', + ragAPIRequest: 'Requêtes API RAG', + bulkUpload: 'Téléchargement en masse de documents', + agentMode: 'Mode Agent', + workflow: 'Flux de travail', + }, + comingSoon: 'Bientôt disponible', + member: 'Membre', + memberAfter: 'Membre', + messageRequest: { + title: 'Crédits de message', + tooltip: 'Quotas d\'invocation de messages pour divers plans utilisant les modèles OpenAI (sauf gpt4). Les messages dépassant la limite utiliseront votre clé API OpenAI.', + }, + annotatedResponse: { + title: 'Limites de quota d\'annotation', + tooltip: 'L\'édition manuelle et l\'annotation des réponses fournissent des capacités de réponse aux questions de haute qualité personnalisables pour les applications. (Applicable uniquement dans les applications de chat)', + }, + ragAPIRequestTooltip: 'Fait référence au nombre d\'appels API invoquant uniquement les capacités de traitement de la base de connaissances de Dify.', + receiptInfo: 'Seuls le propriétaire de l\'équipe et l\'administrateur de l\'équipe peuvent s\'abonner et consulter les informations de facturation', + }, + plans: { + sandbox: { + name: 'Bac à sable', + description: '200 essais gratuits de GPT', + includesTitle: 'Inclus :', + }, + professional: { + name: 'Professionnel', + description: 'Pour les individus et les petites équipes afin de débloquer plus de puissance à un prix abordable.', + includesTitle: 'Tout ce qui est dans le plan gratuit, plus :', + }, + team: { + name: 'Équipe', + description: 'Collaborez sans limites et profitez d\'une performance de premier ordre.', + includesTitle: 'Tout ce qui est inclus dans le plan Professionnel, plus :', + }, + enterprise: { + name: 'Entreprise', + description: 'Obtenez toutes les capacités et le support pour les systèmes à grande échelle et critiques pour la mission.', + includesTitle: 'Tout ce qui est inclus dans le plan Équipe, plus :', + }, + }, + vectorSpace: { + fullTip: 'L\'espace vectoriel est plein.', + fullSolution: 'Mettez à niveau votre plan pour obtenir plus d\'espace.', + }, + apps: { + fullTipLine1: 'Mettez à jour votre plan pour', + fullTipLine2: 'construire plus d\'applications.', + }, + annotatedResponse: { + fullTipLine1: 'Mettez à niveau votre plan pour', + fullTipLine2: 'annotez plus de conversations.', + quotaTitle: 'Quota de Réponse d\'Annotation', + }, +} + +export default translation diff --git a/web/i18n/fr-FR/common.ts b/web/i18n/fr-FR/common.ts new file mode 100644 index 0000000000..fdbfd5be7e --- /dev/null +++ b/web/i18n/fr-FR/common.ts @@ -0,0 +1,505 @@ +const translation = { + api: { + success: 'Succès', + actionSuccess: 'Action réussie', + saved: 'Sauvegardé', + create: 'Créé', + remove: 'Supprimé', + }, + operation: { + create: 'Créer', + confirm: 'Confirmer', + cancel: 'Annuler', + clear: 'Clair', + save: 'Enregistrer', + edit: 'Modifier', + add: 'Ajouter', + added: 'Ajouté', + refresh: 'Redémarrer', + reset: 'Réinitialiser', + search: 'Recherche', + change: 'Changer', + remove: 'Supprimer', + send: 'Envoyer', + copy: 'Copier', + lineBreak: 'Saut de ligne', + sure: 'Je suis sûr', + download: 'Télécharger', + delete: 'Supprimer', + settings: 'Paramètres', + setup: 'Configuration', + getForFree: 'Obtenez gratuitement', + reload: 'Recharger', + ok: 'D\'accord', + log: 'Journal', + learnMore: 'En savoir plus', + params: 'Paramètres', + }, + placeholder: { + input: 'Veuillez entrer', + select: 'Veuillez sélectionner', + }, + voice: { + language: { + zhHans: 'Chinois', + enUS: 'Anglais', + 
deDE: 'Allemand',
+      frFR: 'Français',
+      esES: 'Espagnol',
+      itIT: 'Italien',
+      thTH: 'Thaï.',
+      idID: 'Indonésien',
+      jaJP: 'Japonais',
+      koKR: 'Coréen',
+      ptBR: 'Portugais',
+      ruRU: 'Russe',
+      ukUA: 'Ukrainien',
+    },
+  },
+  unit: {
+    char: 'caractères',
+  },
+  actionMsg: {
+    noModification: 'Aucune modification pour le moment.',
+    modifiedSuccessfully: 'Modifié avec succès',
+    modifiedUnsuccessfully: 'Modifié sans succès',
+    copySuccessfully: 'Copié avec succès',
+    paySucceeded: 'Paiement réussi',
+    payCancelled: 'Paiement annulé',
+    generatedSuccessfully: 'Généré avec succès',
+    generatedUnsuccessfully: 'Généré sans succès',
+  },
+  model: {
+    params: {
+      temperature: 'Température',
+      temperatureTip:
+        'Contrôle l\'aléatoire : une valeur plus basse produit des complétions moins aléatoires. À mesure que la température approche de zéro, le modèle devient déterministe et répétitif.',
+      top_p: 'Haut P',
+      top_pTip:
+        'Contrôle la diversité via l\'échantillonnage nucleus : 0,5 signifie que la moitié des options pondérées par probabilité sont prises en compte.',
+      presence_penalty: 'Pénalité de présence',
+      presence_penaltyTip:
+        'Combien pénaliser les nouveaux tokens en fonction de leur apparition dans le texte jusqu\'à présent. Augmente la probabilité du modèle de parler de nouveaux sujets.',
+      frequency_penalty: 'Pénalité de fréquence',
+      frequency_penaltyTip:
+        'Combien pénaliser les nouveaux tokens en fonction de leur fréquence existante dans le texte jusqu\'à présent. Réduit la probabilité du modèle de répéter la même ligne mot pour mot.',
+      max_tokens: 'Max jeton',
+      max_tokensTip:
+        'Utilisé pour limiter la longueur maximale de la réponse, en jetons. \nDes valeurs plus grandes peuvent limiter l\'espace restant pour les mots de prompt, les journaux de chat, et la Connaissance. \nIl est recommandé de le régler en dessous des 2/3.',
+      maxTokenSettingTip: 'Votre réglage de max token est élevé, limitant potentiellement l\'espace pour les prompts, les requêtes et les données. Envisagez de le définir en dessous de 2/3.',
+      setToCurrentModelMaxTokenTip: 'Le max token est mis à jour à 80% du max token du modèle actuel {{maxToken}}.',
+      stop_sequences: 'Séquences d\'arrêt',
+      stop_sequencesTip: 'Jusqu\'à quatre séquences où l\'API arrêtera de générer d\'autres tokens. Le texte renvoyé ne contiendra pas la séquence d\'arrêt.',
+      stop_sequencesPlaceholder: 'Entrez la séquence et appuyez sur Tab',
+    },
+    tone: {
+      Creative: 'Créatif',
+      Balanced: 'Équilibré',
+      Precise: 'Précis',
+      Custom: 'Personnalisé',
+    },
+    addMoreModel: 'Allez dans les paramètres pour ajouter plus de modèles',
+  },
+  menus: {
+    status: 'bêta',
+    explore: 'Explorer',
+    apps: 'Studio',
+    plugins: 'Plugins',
+    pluginsTips: 'Intégrez des plugins tiers ou créez des AI-Plugins compatibles avec ChatGPT.',
+    datasets: 'Connaissance',
+    datasetsTips: 'BIENTÔT DISPONIBLE : importez vos propres données textuelles ou écrivez des données en temps réel via Webhook pour améliorer le contexte du LLM.',
+    newApp: 'Nouvelle Application',
+    newDataset: 'Créer des Connaissances',
+    tools: 'Outils',
+  },
+  userProfile: {
+    settings: 'Paramètres',
+    workspace: 'Espace de travail',
+    createWorkspace: 'Créer un Espace de Travail',
+    helpCenter: 'Aide',
+    roadmapAndFeedback: 'Feuille de route & Retour d\'information',
+    community: 'Communauté',
+    about: 'À propos',
+    logout: 'Se déconnecter',
+  },
+  settings: {
+    accountGroup: 'COMPTE',
+    workplaceGroup: 'ESPACE DE TRAVAIL',
+    account: 'Mon compte',
+    members: 'Membres',
+    billing: 'Facturation',
+    integrations: 'Intégrations',
+    language: 'Langue',
+    provider: 'Fournisseur de Modèle',
+    dataSource: 'Source de Données',
+    plugin: 'Plugins',
+    apiBasedExtension: 'Extension API',
+  },
+  account: {
+    avatar: 'Avatar',
+    name: 'Nom',
+    email: 'Courriel',
+    password: 'Mot de passe',
+    passwordTip: 'Vous pouvez définir un mot de passe permanent si vous ne souhaitez pas utiliser des codes de connexion temporaires.',
+    setPassword: 'Définir un mot de passe',
+    resetPassword: 'Réinitialiser le mot de passe',
+    currentPassword: 'Mot de passe actuel',
+    newPassword: 'Nouveau mot de passe',
+    confirmPassword: 'Confirmer le mot de passe',
+    notEqual: 'Les deux mots de passe sont différents.',
+    langGeniusAccount: 'Compte Dify',
+    langGeniusAccountTip: 'Votre compte Dify et les données utilisateur associées.',
+    editName: 'Modifier le nom',
+    showAppLength: 'Afficher {{length}} applications',
+  },
+  members: {
+    team: 'Équipe',
+    invite: 'Ajouter',
+    name: 'NOM',
+    lastActive: 'DERNIÈRE ACTIVITÉ',
+    role: 'RÔLES',
+    pending: 'En attente...',
+    owner: 'Propriétaire',
+    admin: 'Administrateur',
+    adminTip: 'Peut construire des applications & gérer les paramètres de l\'équipe',
+    normal: 'Normal',
+    normalTip: 'Peut seulement utiliser des applications, ne peut pas construire des applications',
+    inviteTeamMember: 'Ajouter un membre de l\'équipe',
+    inviteTeamMemberTip: 'Ils peuvent accéder directement à vos données d\'équipe après s\'être connectés.',
+    email: 'Courrier électronique',
+    emailInvalid: 'Format de courriel invalide',
+    emailPlaceholder: 'Veuillez entrer des emails',
+    sendInvite: 'Envoyer une invitation',
+    invitedAsRole: 'Invité en tant qu\'utilisateur {{role}}',
+    invitationSent: 'Invitation envoyée',
+    invitationSentTip: 'Invitation envoyée, et ils peuvent se connecter à Dify pour accéder aux données de votre équipe.',
+    invitationLink: 'Lien d\'invitation',
+    failedinvitationEmails: 'Les utilisateurs ci-dessous n\'ont pas été invités avec succès',
+    ok: 'D\'accord',
+    removeFromTeam: 'Retirer de l\'équipe',
+    removeFromTeamTip: 'Supprimera l\'accès de l\'équipe',
+    setAdmin: 'Définir comme administrateur',
+    setMember: 'Définir en tant que membre ordinaire',
+    disinvite: 'Annuler l\'invitation',
+    deleteMember: 'Supprimer Membre',
+    you: '(Vous)',
+  },
+  integrations: {
+    connected: 'Connecté',
+ google: 'Google', + googleAccount: 'Connectez-vous avec un compte Google', + github: 'GitHub', + githubAccount: 'Connectez-vous avec un compte GitHub', + connect: 'Connecter', + }, + language: { + displayLanguage: 'Langue d\'affichage', + timezone: 'Fuseau horaire', + }, + provider: { + apiKey: 'Clé API', + enterYourKey: 'Entrez votre clé API ici', + invalidKey: 'Clé API OpenAI invalide', + validatedError: 'Validation failed: ', + validating: 'Validation de la clé...', + saveFailed: 'La sauvegarde de la clé API a échoué', + apiKeyExceedBill: 'Cette clé API n\'a pas de quota disponible, veuillez lire', + addKey: 'Ajouter une clé', + comingSoon: 'Bientôt disponible', + editKey: 'Modifier', + invalidApiKey: 'Clé API invalide', + azure: { + apiBase: 'Base de l\'API', + apiBasePlaceholder: 'L\'URL de base de l\'API de votre point de terminaison Azure OpenAI.', + apiKey: 'Clé API', + apiKeyPlaceholder: 'Entrez votre clé API ici', + helpTip: 'Apprenez le service OpenAI Azure', + }, + openaiHosted: { + openaiHosted: 'OpenAI Hébergé', + onTrial: 'EN ESSAI', + exhausted: 'QUOTA ÉPUISÉ', + desc: 'Le service d\'hébergement OpenAI fourni par Dify vous permet d\'utiliser des modèles tels que GPT-3.5. Avant que votre quota d\'essai ne soit épuisé, vous devez configurer d\'autres fournisseurs de modèles.', + callTimes: 'Temps d\'appel', + usedUp: 'Quota d\'essai épuisé. Ajoutez votre propre fournisseur de modèle.', + useYourModel: 'Utilise actuellement son propre fournisseur de modèle.', + close: 'Fermer', + }, + anthropicHosted: { + anthropicHosted: 'Anthropic Claude', + onTrial: 'EN ESSAI', + exhausted: 'QUOTA ÉPUISÉ', + desc: 'Modèle puissant, qui excelle dans une large gamme de tâches allant du dialogue sophistiqué et de la génération de contenu créatif à l\'instruction détaillée.', + callTimes: 'Temps d\'appel', + usedUp: 'Quota d\'essai épuisé. Ajoutez votre propre fournisseur de modèle.', + useYourModel: 'Utilise actuellement son propre fournisseur de modèle.', + close: 'Fermer', + }, + anthropic: { + using: 'La capacité d\'embedding est utilisée', + enableTip: 'Pour activer le modèle Anthropic, vous devez d\'abord vous lier à OpenAI ou au service Azure OpenAI.', + notEnabled: 'Non activé', + keyFrom: 'Obtenez votre clé API de chez Anthropic', + }, + encrypted: { + front: 'Votre clé API sera chiffrée et stockée en utilisant', + back: 'technologie.', + }, + }, + modelProvider: { + notConfigured: 'Le modèle du système n\'a pas encore été entièrement configuré, et certaines fonctions peuvent être indisponibles.', + systemModelSettings: 'Paramètres du Modèle Système', + systemModelSettingsLink: 'Pourquoi est-il nécessaire de mettre en place un modèle de système ?', + selectModel: 'Sélectionnez votre modèle', + setupModelFirst: 'Veuillez d\'abord configurer votre modèle', + systemReasoningModel: { + key: 'Modèle de Raisonnement du Système', + tip: 'Définissez le modèle d\'inférence par défaut à utiliser pour la création d\'applications, ainsi que des fonctionnalités telles que la génération de noms de dialogue et la suggestion de la prochaine question utiliseront également le modèle d\'inférence par défaut.', + }, + embeddingModel: { + key: 'Modèle d\'Embedding', + tip: 'Définissez le modèle par défaut pour le traitement d\'incorporation de documents de la Connaissance, à la fois la récupération et l\'importation de la Connaissance utilisent ce modèle d\'Embedding pour le traitement de vectorisation. 
Si vous changez de modèle, la dimension du vecteur entre la connaissance importée et la question ne sera pas cohérente, ce qui entraînera un échec de la recherche. Pour éviter les échecs de recherche, veuillez ne pas changer de modèle à volonté.',
+      required: 'Le modèle d\'embedding est requis',
+    },
+    speechToTextModel: {
+      key: 'Modèle de Parole-à-Texte',
+      tip: 'Définissez le modèle par défaut pour l\'entrée de texte par la parole dans la conversation.',
+    },
+    ttsModel: {
+      key: 'Modèle de Texte-à-Parole',
+      tip: 'Définissez le modèle par défaut pour l\'entrée de texte à la parole dans une conversation.',
+    },
+    rerankModel: {
+      key: 'Modèle de Réorganisation',
+      tip: 'Le modèle de réorganisation réorganisera la liste des documents candidats en fonction de la correspondance sémantique avec la requête de l\'utilisateur, améliorant ainsi les résultats du classement sémantique.',
+    },
+    quota: 'Quota',
+    searchModel: 'Modèle de recherche',
+    noModelFound: 'Aucun modèle trouvé pour {{model}}',
+    models: 'Modèles',
+    showMoreModelProvider: 'Montrer plus de fournisseur de modèle',
+    selector: {
+      tip: 'Ce modèle a été supprimé. Veuillez ajouter un modèle ou sélectionner un autre modèle.',
+      emptyTip: 'Aucun modèle disponible',
+      emptySetting: 'Veuillez aller dans les paramètres pour configurer',
+      rerankTip: 'Veuillez configurer le modèle Rerank',
+    },
+    card: {
+      quota: 'QUOTA',
+      onTrial: 'En Essai',
+      paid: 'Payé',
+      quotaExhausted: 'Quota épuisé',
+      callTimes: 'Temps d\'appel',
+      tokens: 'Jetons',
+      buyQuota: 'Acheter Quota',
+      priorityUse: 'Utilisation prioritaire',
+      removeKey: 'Supprimer la clé API',
+      tip: 'La priorité sera donnée au quota payant. Le quota d\'essai sera utilisé après épuisement du quota payant.',
+    },
+    item: {
+      deleteDesc: '{{modelName}} sont utilisés comme modèles de raisonnement système. Certaines fonctions ne seront pas disponibles après la suppression. 
Veuillez confirmer.',
+      freeQuota: 'QUOTA GRATUIT',
+    },
+    addApiKey: 'Ajoutez votre clé API',
+    invalidApiKey: 'Clé API invalide',
+    encrypted: {
+      front: 'Votre clé API sera cryptée et stockée en utilisant',
+      back: 'technologie.',
+    },
+    freeQuota: {
+      howToEarn: 'Comment gagner',
+    },
+    addMoreModelProvider: 'AJOUTER PLUS DE FOURNISSEUR DE MODÈLE',
+    addModel: 'Ajouter un modèle',
+    modelsNum: '{{num}} Modèles',
+    showModels: 'Montrer les modèles',
+    showModelsNum: 'Afficher {{num}} Modèles',
+    collapse: 'Réduire',
+    config: 'Configuration',
+    modelAndParameters: 'Modèle et Paramètres',
+    model: 'Modèle',
+    featureSupported: '{{feature}} pris en charge',
+    callTimes: 'Temps d\'appel',
+    credits: 'Crédits de Messages',
+    buyQuota: 'Acheter Quota',
+    getFreeTokens: 'Obtenez des Tokens gratuits',
+    priorityUsing: 'Prioriser l\'utilisation',
+    deprecated: 'Obsolète',
+    confirmDelete: 'confirmer la suppression?',
+    quotaTip: 'Tokens gratuits restants disponibles',
+    loadPresets: 'Charger les préréglages',
+    parameters: 'PARAMÈTRES',
+  },
+  dataSource: {
+    add: 'Ajouter une source de données',
+    connect: 'Connecter',
+    notion: {
+      title: 'Notion',
+      description: 'Utiliser Notion comme source de données pour la Connaissance.',
+      connectedWorkspace: 'Espace de travail connecté',
+      addWorkspace: 'Ajouter un espace de travail',
+      connected: 'Connecté',
+      disconnected: 'Déconnecté',
+      changeAuthorizedPages: 'Modifier les pages autorisées',
+      pagesAuthorized: 'Pages autorisées',
+      sync: 'Synchronisation',
+      remove: 'Supprimer',
+      selector: {
+        pageSelected: 'Pages Sélectionnées',
+        searchPages: 'Rechercher des pages...',
+        noSearchResult: 'Aucun résultat de recherche',
+        addPages: 'Ajouter des pages',
+        preview: 'APERÇU',
+      },
+    },
+  },
+  plugin: {
+    serpapi: {
+      apiKey: 'Clé API',
+      apiKeyPlaceholder: 'Entrez votre clé API',
+      keyFrom: 'Obtenez votre clé SerpAPI depuis la page de compte SerpAPI',
+    },
+  },
+  apiBasedExtension: {
+    title: 'Les extensions API fournissent une gestion centralisée des API, simplifiant la configuration pour une utilisation facile à travers les applications de Dify.',
+    link: 'Apprenez comment développer votre propre Extension API.',
+    linkUrl: 'https://docs.dify.ai/features/extension/api_based_extension',
+    add: 'Ajouter l\'extension API',
+    selector: {
+      title: 'Extension de l\'API',
+      placeholder: 'Veuillez sélectionner l\'extension API',
+      manage: 'Gérer l\'extension API',
+    },
+    modal: {
+      title: 'Ajouter une extension API',
+      editTitle: 'Modifier l\'extension API',
+      name: {
+        title: 'Nom',
+        placeholder: 'Veuillez entrer le nom',
+      },
+      apiEndpoint: {
+        title: 'Point de terminaison API',
+        placeholder: 'Veuillez entrer le point de terminaison de l\'API',
+      },
+      apiKey: {
+        title: 'clé API',
+        placeholder: 'Veuillez entrer la clé API',
+        lengthError: 'La longueur de la clé API ne peut pas être inférieure à 5 caractères',
+      },
+    },
+    type: 'Type',
+  },
+  about: {
+    changeLog: 'Journal des modifications',
+    updateNow: 'Mettre à jour maintenant',
+    nowAvailable: 'Dify {{version}} est maintenant disponible.',
+    latestAvailable: 'Dify {{version}} est la dernière version disponible.',
+  },
+  appMenus: {
+    overview: 'Aperçu',
+    promptEng: 'Orchestrer',
+    apiAccess: 'Accès API',
+    logAndAnn: 'Journaux & Annot.',
+  },
+  environment: {
+    testing: 'TESTER',
+    development: 'DÉVELOPPEMENT',
+  },
+  appModes: {
+    completionApp: 'Générateur de Texte',
+    chatApp: 'Appli de Chat',
+  },
+  datasetMenus: {
+    documents: 'Documents',
+    hitTesting: 'Test de Récupération',
+    settings: 'Paramètres',
+    emptyTip: 'La Connaissance n\'a pas été associée, veuillez aller à l\'application ou au plug-in pour compléter l\'association.',
+    viewDoc: 'Voir la documentation',
+    relatedApp: 'applications liées',
+  },
+  voiceInput: {
+    speaking: 'Parle maintenant...',
+    converting: 'Conversion en texte...',
+    notAllow: 'microphone non autorisé',
+  },
+  modelName: {
+    'gpt-3.5-turbo': 'GPT-3.5-Turbo',
+    'gpt-3.5-turbo-16k': 'GPT-3.5-Turbo-16K',
+    'gpt-4': 'GPT-4',
+    'gpt-4-32k': 'GPT-4-32K',
+    'text-davinci-003': 'Text-Davinci-003',
+    'text-embedding-ada-002': 'Text-Embedding-Ada-002',
+    'whisper-1': 'Whisper-1',
+    'claude-instant-1': 'Claude-Instant',
+    'claude-2': 'Claude-2',
+  },
+  chat: {
+    renameConversation: 'Renommer la conversation',
+    conversationName: 'Nom de la conversation',
+    conversationNamePlaceholder: 'Veuillez entrer le nom de la conversation',
+    conversationNameCanNotEmpty: 'Nom de la conversation requis',
+    citation: {
+      title: 'CITATIONS',
+      linkToDataset: 'Lien vers la Connaissance',
+      characters: 'Caractères :',
+      hitCount: 'Nombre de récupérations :',
+      vectorHash: 'Hachage vectoriel :',
+      hitScore: 'Score de Récupération :',
+    },
+  },
+  promptEditor: {
+    placeholder: 'Écrivez votre mot d\'invite ici, entrez \'{\' pour insérer une variable, entrez \'/\' pour insérer un bloc de contenu d\'invite',
+    context: {
+      item: {
+        title: 'Contexte',
+        desc: 'Insérez le modèle de contexte',
+      },
+      modal: {
+        title: '{{num}} Connaissance en Contexte',
+        add: 'Ajouter Contexte',
+        footer: 'Vous pouvez gérer les contextes dans la section Contexte ci-dessous.',
+      },
+    },
+    history: {
+      item: {
+        title: 'Historique des conversations',
+        desc: 'Insérer le modèle de message historique',
+      },
+      modal: {
+        title: 'EXEMPLE',
+        user: 'Bonjour',
+        assistant: 'Bonjour ! 
Comment puis-je vous aider aujourd\'hui ?', + edit: 'Modifier les Noms des Rôles de Conversation', + }, + }, + variable: { + item: { + title: 'Variables & Outils Externes', + desc: 'Insérer des Variables & Outils Externes', + }, + modal: { + add: 'Nouvelle variable', + addTool: 'Nouvel outil', + }, + }, + query: { + item: { + title: 'Requête', + desc: 'Insérez le modèle de requête utilisateur', + }, + }, + existed: 'Existe déjà dans le prompt', + }, + imageUploader: { + uploadFromComputer: 'Télécharger depuis l\'ordinateur', + uploadFromComputerReadError: 'La lecture de l\'image a échoué, veuillez réessayer.', + uploadFromComputerUploadError: 'Le téléchargement de l\'image a échoué, veuillez télécharger à nouveau.', + uploadFromComputerLimit: 'Le téléchargement d\'images ne peut pas dépasser {{size}} MB', + pasteImageLink: 'Collez le lien de l\'image', + pasteImageLinkInputPlaceholder: 'Collez le lien de l\'image ici', + pasteImageLinkInvalid: 'Lien d\'image invalide', + imageUpload: 'Téléchargement d\'image', + }, +} + +export default translation diff --git a/web/i18n/fr-FR/custom.ts b/web/i18n/fr-FR/custom.ts new file mode 100644 index 0000000000..c0c651cdb7 --- /dev/null +++ b/web/i18n/fr-FR/custom.ts @@ -0,0 +1,30 @@ +const translation = { + custom: 'Personnalisation', + upgradeTip: { + prefix: 'Mettez à niveau votre plan pour', + suffix: 'personnalisez votre marque.', + }, + webapp: { + title: 'Personnalisez la marque WebApp', + removeBrand: 'Supprimer Propulsé par Dify', + changeLogo: 'Changer Propulsé par l\'Image de Marque', + changeLogoTip: 'Format SVG ou PNG avec une taille minimum de 40x40px', + }, + app: { + title: 'Personnaliser la marque de l\'en-tête de l\'application', + changeLogoTip: 'Format SVG ou PNG avec une taille minimale de 80x80px', + }, + upload: 'Télécharger', + uploading: 'Téléchargement', + uploadedFail: 'Le téléchargement de l\'image a échoué, veuillez la télécharger à nouveau.', + change: 'Changer', + apply: 'Appliquer', + restore: 'Rétablir les paramètres par défaut', + customize: { + contactUs: 'Contactez-nous', + prefix: 'Pour personnaliser le logo de la marque dans l\'application, s\'il vous plaît', + suffix: 'pour passer à l\'édition Enterprise.', + }, +} + +export default translation diff --git a/web/i18n/fr-FR/dataset-creation.ts b/web/i18n/fr-FR/dataset-creation.ts new file mode 100644 index 0000000000..da3ac8d476 --- /dev/null +++ b/web/i18n/fr-FR/dataset-creation.ts @@ -0,0 +1,130 @@ +const translation = { + steps: { + header: { + creation: 'Créer des Connaissances', + update: 'Ajouter des données', + }, + one: 'Choisissez la source de données', + two: 'Prétraitement et Nettoyage du Texte', + three: 'Exécutez et terminez', + }, + error: { + unavailable: 'Cette connaissance n\'est pas disponible', + }, + stepOne: { + filePreview: 'Aperçu du fichier', + pagePreview: 'Aperçu de la page', + dataSourceType: { + file: 'Importer à partir d\'un fichier texte', + notion: 'Synchroniser depuis Notion', + web: 'Synchroniser depuis le site web', + }, + uploader: { + title: 'Télécharger le fichier texte', + button: 'Glisser et déposer le fichier, ou', + browse: 'Parcourir', + tip: 'Prend en charge {{supportTypes}}. Max {{size}}MB chacun.', + validation: { + typeError: 'Type de fichier non pris en charge', + size: 'Fichier trop volumineux. 
Le maximum est de {{size}}MB', + count: 'Plusieurs fichiers non pris en charge', + filesNumber: 'Vous avez atteint la limite de téléchargement par lot de {{filesNumber}}.', + }, + cancel: 'Annuler', + change: 'Changer', + failed: 'Le téléchargement a échoué', + }, + notionSyncTitle: 'Notion n\'est pas connecté', + notionSyncTip: 'Pour synchroniser avec Notion, une connexion à Notion doit d\'abord être établie.', + connect: 'Aller à connecter', + button: 'suivant', + emptyDatasetCreation: 'Je veux créer un Savoir vide', + modal: { + title: 'Créer une Connaissance vide', + tip: 'Une Connaissance vide ne contiendra aucun document, et vous pouvez télécharger des documents à tout moment.', + input: 'Nom de la connaissance', + placeholder: 'Veuillez entrer', + nameNotEmpty: 'Le nom ne peut pas être vide', + nameLengthInvaild: 'Le nom doit comporter entre 1 et 40 caractères.', + cancelButton: 'Annuler', + confirmButton: 'Créer', + failed: 'Création échouée', + }, + }, + stepTwo: { + segmentation: 'Paramètres de bloc', + auto: 'Automatique', + autoDescription: 'Définir automatiquement les règles de découpage et de prétraitement. Il est recommandé aux utilisateurs non familiers de sélectionner ceci.', + custom: 'Personnalisé', + customDescription: 'Personnalisez les règles de morceaux, la longueur des morceaux et les règles de prétraitement, etc.', + separator: 'Identifiant de segment', + separatorPlaceholder: 'Par exemple, nouvelle ligne (\\\\n) ou séparateur spécial (tel que "***")', + maxLength: 'Longueur maximale du morceau', + overlap: 'Chevauchement de morceaux', + overlapTip: 'La définition d\'un chevauchement de morceaux peut maintenir la pertinence sémantique entre eux, améliorant ainsi l\'effet de récupération. Il est recommandé de définir 10%-25% de la taille maximale du morceau.', + overlapCheck: 'le chevauchement de morceaux ne doit pas être plus grand que la longueur maximale de morceau', + rules: 'Règles de prétraitement du texte', + removeExtraSpaces: 'Remplacer les espaces consécutifs, les sauts de ligne et les tabulations', + removeUrlEmails: 'Supprimez toutes les URL et adresses e-mail', + removeStopwords: 'Supprimez les mots vides tels que "a", "an", "the"', + preview: 'Confirmer & Aperçu', + reset: 'Réinitialiser', + indexMode: 'Mode d\'index', + qualified: 'Haute Qualité', + recommend: 'Recommander', + qualifiedTip: 'Appelez l\'interface d\'embedding système par défaut pour le traitement afin de fournir une précision plus élevée lorsque les utilisateurs font une requête.', + warning: 'Veuillez d\'abord configurer la clé API du fournisseur de modèle.', + click: 'Aller aux paramètres', + economical: 'Économique', + economicalTip: 'Utilisez des moteurs vectoriels hors ligne, des index de mots-clés, etc. 
pour réduire la précision sans dépenser de jetons',
+    QATitle: 'Segmentation en format Question & Réponse',
+    QATip: 'Activer cette option consommera plus de jetons',
+    QALanguage: 'Segmenter en utilisant',
+    emstimateCost: 'Estimation',
+    emstimateSegment: 'Morceaux estimés',
+    segmentCount: 'morceaux',
+    calculating: 'En calcul...',
+    fileSource: 'Prétraiter les documents',
+    notionSource: 'Prétraiter les pages',
+    other: 'et autres',
+    fileUnit: 'fichiers',
+    notionUnit: 'pages',
+    previousStep: 'Étape précédente',
+    nextStep: 'Enregistrer & Traiter',
+    save: 'Enregistrer & Traiter',
+    cancel: 'Annuler',
+    sideTipTitle: 'Pourquoi découper et prétraiter ?',
+    sideTipP1: 'Lors du traitement des données textuelles, le découpage et le nettoyage sont deux étapes importantes du prétraitement.',
+    sideTipP2: 'La segmentation divise les longs textes en paragraphes afin que les modèles puissent mieux comprendre. Cela améliore la qualité et la pertinence des résultats du modèle.',
+    sideTipP3: 'Le nettoyage élimine les caractères et les formats inutiles, rendant le Savoir plus propre et plus facile à analyser.',
+    sideTipP4: 'Un bon découpage et nettoyage améliorent les performances du modèle, fournissant des résultats plus précis et précieux.',
+    previewTitle: 'Aperçu',
+    previewTitleButton: 'Aperçu',
+    previewButton: 'Passage au format Q&R',
+    previewSwitchTipStart: 'L\'aperçu actuel du morceau est en format texte, passer à un aperçu en format de questions-réponses va',
+    previewSwitchTipEnd: 'consommer des tokens supplémentaires',
+    characters: 'caractères',
+    indexSettedTip: 'Pour changer la méthode d\'index, veuillez aller à la',
+    retrivalSettedTip: 'Pour changer la méthode de récupération, veuillez aller à la',
+    datasetSettingLink: 'Paramètres de connaissance.',
+  },
+  stepThree: {
+    creationTitle: '🎉 Connaissance créée',
+    creationContent: 'Nous avons automatiquement nommé le Savoir, vous pouvez le modifier à tout moment',
+    label: 'Nom de la connaissance',
+    additionTitle: '🎉 Document téléchargé',
+    additionP1: 'Le document a été téléchargé dans la Connaissance',
+    additionP2: ', vous pouvez le trouver dans la liste des documents de la Connaissance.',
+    stop: 'Arrêter le traitement',
+    resume: 'Reprendre le traitement',
+    navTo: 'Aller au document',
+    sideTipTitle: 'Qu\'est-ce qui suit ?',
+    sideTipContent: 'Après l\'indexation du document, la Connaissance peut être intégrée dans l\'application en tant que contexte, vous pouvez trouver le paramètre de contexte sur la page d\'orchestration de prompt. 
Vous pouvez également le créer en tant que plugin d\'indexation ChatGPT indépendant',
+    modelTitle: 'Êtes-vous sûr de vouloir arrêter l\'embedding ?',
+    modelContent: 'Si vous devez reprendre le traitement plus tard, vous continuerez à partir de l\'endroit où vous vous êtes arrêté.',
+    modelButtonConfirm: 'Confirmer',
+    modelButtonCancel: 'Annuler',
+  },
+}
+
+export default translation
diff --git a/web/i18n/fr-FR/dataset-documents.ts b/web/i18n/fr-FR/dataset-documents.ts
new file mode 100644
index 0000000000..c6b0fca1df
--- /dev/null
+++ b/web/i18n/fr-FR/dataset-documents.ts
@@ -0,0 +1,349 @@
+const translation = {
+  list: {
+    title: 'Documents',
+    desc: 'Tous les fichiers de la Connaissance sont affichés ici, et l\'ensemble de la Connaissance peut être lié aux citations Dify ou indexé via le plugin Chat.',
+    addFile: 'ajouter un fichier',
+    addPages: 'Ajouter des Pages',
+    table: {
+      header: {
+        fileName: 'NOM DU FICHIER',
+        words: 'MOTS',
+        hitCount: 'NOMBRE DE RÉCUPÉRATIONS',
+        uploadTime: 'TEMPS DE TÉLÉCHARGEMENT',
+        status: 'STATUT',
+        action: 'ACTION',
+      },
+    },
+    action: {
+      uploadFile: 'Télécharger un nouveau fichier',
+      settings: 'Paramètres de segment',
+      addButton: 'Ajouter un morceau',
+      add: 'Ajouter un morceau',
+      batchAdd: 'Ajout en lot',
+      archive: 'Archiver',
+      unarchive: 'Désarchiver',
+      delete: 'Supprimer',
+      enableWarning: 'Le fichier archivé ne peut pas être activé',
+      sync: 'Synchroniser',
+    },
+    index: {
+      enable: 'Activer',
+      disable: 'Désactiver',
+      all: 'Tout',
+      enableTip: 'Le fichier peut être indexé',
+      disableTip: 'Le fichier ne peut pas être indexé',
+    },
+    status: {
+      queuing: 'Mise en file d\'attente',
+      indexing: 'Indexation',
+      paused: 'En pause',
+      error: 'Erreur',
+      available: 'Disponible',
+      enabled: 'Activé',
+      disabled: 'Désactivé',
+      archived: 'Archivé',
+    },
+    empty: {
+      title: 'Il n\'y a pas encore de documentation',
+      upload: {
+        tip: 'Vous pouvez télécharger des fichiers, synchroniser à partir du site web, ou à partir d\'applications web comme Notion, GitHub, etc.',
+      },
+      sync: {
+        tip: 'Dify téléchargera périodiquement des fichiers de votre Notion et terminera le traitement.',
+      },
+    },
+    delete: {
+      title: 'Êtes-vous sûr de vouloir supprimer ?',
+      content: 'Si vous avez besoin de reprendre le traitement plus tard, vous continuerez à partir de l\'endroit où vous vous êtes arrêté',
+    },
+    batchModal: {
+      title: 'Ajouter des lots de segments',
+      csvUploadTitle: 'Faites glisser et déposez votre fichier CSV ici, ou',
+      browse: 'parcourir',
+      tip: 'Le fichier CSV doit se conformer à la structure suivante :',
+      question: 'question',
+      answer: 'réponse',
+      contentTitle: 'contenu du bloc',
+      content: 'contenu',
+      template: 'Téléchargez le modèle ici',
+      cancel: 'Annuler',
+      run: 'Exécuter le lot',
+      runError: 'L\'exécution du lot a échoué',
+      processing: 'Traitement par lots en cours',
+      completed: 'Importation terminée',
+      error: 'Erreur d\'Importation',
+      ok: 'D\'accord',
+    },
+  },
+  metadata: {
+    title: 'Métadonnées',
+    desc: 'L\'étiquetage des métadonnées pour les documents permet à l\'IA d\'y accéder en temps opportun et expose la source des références pour les utilisateurs.',
+    dateTimeFormat: 'MMMM D, YYYY hh:mm A',
+    docTypeSelectTitle: 'Veuillez sélectionner un type de document',
+    docTypeChangeTitle: 'Changer le type de document',
+    docTypeSelectWarning:
+      'Si le type de document est modifié, les métadonnées actuellement remplies ne seront plus conservées',
+    firstMetaAction: 'Allons-y',
+    placeholder: {
+      add: 'Ajouter',
+      select: 'Sélectionner',
+
}, + source: { + upload_file: 'Télécharger le fichier', + notion: 'Synchroniser le formulaire depuis Notion', + github: 'Synchroniser à partir de Github', + }, + type: { + book: 'Livre', + webPage: 'Page Web', + paper: 'Papier', + socialMediaPost: 'Publication sur les Réseaux Sociaux', + personalDocument: 'Document Personnel', + businessDocument: 'Document Commercial', + IMChat: 'Chat IM', + wikipediaEntry: 'Entrée Wikipédia', + notion: 'Synchroniser depuis Notion', + github: 'Synchroniser depuis Github', + technicalParameters: 'Paramètres Techniques', + }, + field: { + processRule: { + processDoc: 'Document de Processus', + segmentRule: 'Règle de Segment', + segmentLength: 'Longueur des Morceaux', + processClean: 'Processus de Nettoyage du Texte', + }, + book: { + title: 'Titre', + language: 'Langue', + author: 'Auteur', + publisher: 'Éditeur', + publicationDate: 'Date de publication', + ISBN: 'ISBN', + category: 'Catégorie', + }, + webPage: { + title: 'Titre', + url: 'URL', + language: 'Langue', + authorPublisher: 'Auteur/Éditeur', + publishDate: 'Date de publication', + topicsKeywords: 'Sujets/Mots-clés', + description: 'Description', + }, + paper: { + title: 'Titre', + language: 'Langue', + author: 'Auteur', + publishDate: 'Date de publication', + journalConferenceName: 'Nom du Journal/Conférence', + volumeIssuePage: 'Volume/Numéro/Page', + DOI: 'DOI', + topicsKeywords: 'Sujets/Mots-clés', + abstract: 'Résumé', + }, + socialMediaPost: { + platform: 'Plateforme', + authorUsername: 'Auteur/Nom d\'utilisateur', + publishDate: 'Date de publication', + postURL: 'URL de publication', + topicsTags: 'Sujets/Tags', + }, + personalDocument: { + title: 'Titre', + author: 'Auteur', + creationDate: 'Date de Création', + lastModifiedDate: 'Date de Dernière Modification', + documentType: 'Type de Document', + tagsCategory: 'Tags/Catégorie', + }, + businessDocument: { + title: 'Titre', + author: 'Auteur', + creationDate: 'Date de création', + lastModifiedDate: 'Date de Dernière Modification', + documentType: 'Type de Document', + departmentTeam: 'Département/Équipe', + }, + IMChat: { + chatPlatform: 'Plateforme de Chat', + chatPartiesGroupName: 'Nom du groupe/Parties de discussion', + participants: 'Participants', + startDate: 'Date de Début', + endDate: 'Date de fin', + topicsKeywords: 'Sujets/Mots-clés', + fileType: 'Type de fichier', + }, + wikipediaEntry: { + title: 'Titre', + language: 'Langue', + webpageURL: 'URL de la page web', + editorContributor: 'Éditeur/Contributeur', + lastEditDate: 'Date de dernière modification', + summaryIntroduction: 'Résumé/Introduction', + }, + notion: { + title: 'Titre', + language: 'Langue', + author: 'Auteur', + createdTime: 'Heure de création', + lastModifiedTime: 'Dernière Modification', + url: 'URL', + tag: 'Étiquette', + description: 'Description', + }, + github: { + repoName: 'Nom du dépôt', + repoDesc: 'Description du dépôt', + repoOwner: 'Propriétaire du dépôt', + fileName: 'Nom du Fichier', + filePath: 'Chemin du fichier', + programmingLang: 'Langage de programmation', + url: 'URL', + license: 'Licence', + lastCommitTime: 'Heure du dernier commit', + lastCommitAuthor: 'Auteur du dernier commit', + }, + originInfo: { + originalFilename: 'Nom de fichier original', + originalFileSize: 'Taille originale du fichier', + uploadDate: 'Date de téléchargement', + lastUpdateDate: 'Date de dernière mise à jour', + source: 'Source', + }, + technicalParameters: { + segmentSpecification: 'Spécification des morceaux', + segmentLength: 'Longueur des morceaux', + 
avgParagraphLength: 'Longueur moyenne de paragraphe',
+        paragraphs: 'Paragraphes',
+        hitCount: 'Nombre de récupérations',
+        embeddingTime: 'Temps d\'incorporation',
+        embeddedSpend: 'Dépenses intégrées',
+      },
+    },
+    languageMap: {
+      zh: 'Chinois',
+      en: 'Anglais',
+      es: 'Espagnol',
+      fr: 'Français',
+      de: 'Allemand',
+      ja: 'Japonais',
+      ko: 'Coréen',
+      ru: 'Russe',
+      ar: 'Arabe',
+      pt: 'Portugais',
+      it: 'Italien',
+      nl: 'Néerlandais',
+      pl: 'Polonais',
+      sv: 'Suédois',
+      tr: 'Turc',
+      he: 'Hébreu',
+      hi: 'Hindi',
+      da: 'Danois',
+      fi: 'Finlandais',
+      no: 'Norvégien',
+      hu: 'Hongrois',
+      el: 'Grec',
+      cs: 'Tchèque',
+      th: 'Thaï',
+      id: 'Indonésien',
+    },
+    categoryMap: {
+      book: {
+        fiction: 'Fiction',
+        biography: 'Biographie',
+        history: 'Histoire',
+        science: 'Science',
+        technology: 'Technologie',
+        education: 'Éducation',
+        philosophy: 'Philosophie',
+        religion: 'Religion',
+        socialSciences: 'Sciences Sociales',
+        art: 'Art',
+        travel: 'Voyage',
+        health: 'Santé',
+        selfHelp: 'Auto-assistance',
+        businessEconomics: 'Économie d\'entreprise',
+        cooking: 'Cuisine',
+        childrenYoungAdults: 'Enfants & Jeunes Adultes',
+        comicsGraphicNovels: 'Bandes Dessinées & Romans Graphiques',
+        poetry: 'Poésie',
+        drama: 'Drame',
+        other: 'Autre',
+      },
+      personalDoc: {
+        notes: 'Notes',
+        blogDraft: 'Brouillon de Blog',
+        diary: 'Journal',
+        researchReport: 'Rapport de Recherche',
+        bookExcerpt: 'Extrait de livre',
+        schedule: 'Programme',
+        list: 'Liste',
+        projectOverview: 'Aperçu du Projet',
+        photoCollection: 'Collection de Photos',
+        creativeWriting: 'Écriture Créative',
+        codeSnippet: 'Extrait de Code',
+        designDraft: 'Projet de Conception',
+        personalResume: 'Curriculum Vitae Personnel',
+        other: 'Autre',
+      },
+      businessDoc: {
+        meetingMinutes: 'Compte-rendu de Réunion',
+        researchReport: 'Rapport de Recherche',
+        proposal: 'Proposition',
+        employeeHandbook: 'Manuel de l\'employé',
+        trainingMaterials: 'Matériaux de Formation',
+        requirementsDocument: 'Document de Spécifications',
+        designDocument: 'Document de Conception',
+        productSpecification: 'Spécification du produit',
+        financialReport: 'Rapport Financier',
+        marketAnalysis: 'Analyse de marché',
+        projectPlan: 'Plan de Projet',
+        teamStructure: 'Structure de l\'équipe',
+        policiesProcedures: 'Politiques & Procédures',
+        contractsAgreements: 'Contrats & Accords',
+        emailCorrespondence: 'Correspondance par Email',
+        other: 'Autre',
+      },
+    },
+  },
+  embedding: {
+    processing: 'Traitement des embeddings...',
+    paused: 'Intégration en pause',
+    completed: 'Intégration terminée',
+    error: 'Erreur d\'embedding',
+    docName: 'Prétraitement du document',
+    mode: 'Règle de segmentation',
+    segmentLength: 'Longueur des morceaux',
+    textCleaning: 'Prétraitement et nettoyage du texte',
+    segments: 'Paragraphes',
+    highQuality: 'Mode haute qualité',
+    economy: 'Mode économique',
+    estimate: 'Consommation estimée',
+    stop: 'Arrêter le traitement',
+    resume: 'Reprendre le traitement',
+    automatic: 'Automatique',
+    custom: 'Personnalisé',
+    previewTip: 'L\'aperçu du paragraphe sera disponible après la fin de l\'embedding.',
+  },
+  segment: {
+    paragraphs: 'Paragraphes',
+    keywords: 'Mots Clés',
+    addKeyWord: 'Ajouter un mot-clé',
+    keywordError: 'La longueur maximale du mot-clé est de 20',
+    characters: 'caractères',
+    hitCount: 'Nombre de récupérations',
+    vectorHash: 'Hachage vectoriel : ',
+    questionPlaceholder: 'ajoutez la question ici',
+    questionEmpty: 'La question ne peut pas être vide',
+    answerPlaceholder: 'ajoutez une réponse ici',
+    answerEmpty: 'La réponse ne
peut pas être vide',
+    contentPlaceholder: 'ajoutez du contenu ici',
+    contentEmpty: 'Le contenu ne peut pas être vide',
+    newTextSegment: 'Nouveau Segment de Texte',
+    newQaSegment: 'Nouveau Segment Q&R',
+    delete: 'Supprimer ce morceau ?',
+  },
+}
+
+export default translation
diff --git a/web/i18n/fr-FR/dataset-hit-testing.ts b/web/i18n/fr-FR/dataset-hit-testing.ts
new file mode 100644
index 0000000000..1e527bcf24
--- /dev/null
+++ b/web/i18n/fr-FR/dataset-hit-testing.ts
@@ -0,0 +1,28 @@
+const translation = {
+  title: 'Test de Récupération',
+  desc: 'Testez l\'effet d\'impact de la Connaissance basée sur le texte de la requête donnée.',
+  dateTimeFormat: 'DD/MM/YYYY hh:mm A',
+  recents: 'Récents',
+  table: {
+    header: {
+      source: 'Source',
+      text: 'Texte',
+      time: 'Temps',
+    },
+  },
+  input: {
+    title: 'Texte source',
+    placeholder: 'Veuillez entrer un texte, une phrase déclarative courte est recommandée.',
+    countWarning: 'Jusqu\'à 200 caractères.',
+    indexWarning: 'Connaissances de haute qualité uniquement.',
+    testing: 'Test',
+  },
+  hit: {
+    title: 'PARAGRAPHES DE RÉCUPÉRATION',
+    emptyTip: 'Les résultats des tests de récupération s\'afficheront ici',
+  },
+  noRecentTip: 'Aucun résultat de requête récent ici',
+  viewChart: 'Voir GRAPHIQUE VECTORIEL',
+}
+
+export default translation
diff --git a/web/i18n/fr-FR/dataset-settings.ts b/web/i18n/fr-FR/dataset-settings.ts
new file mode 100644
index 0000000000..731013738a
--- /dev/null
+++ b/web/i18n/fr-FR/dataset-settings.ts
@@ -0,0 +1,33 @@
+const translation = {
+  title: 'Paramètres de connaissance',
+  desc: 'Ici, vous pouvez modifier les propriétés et les méthodes de fonctionnement de la Connaissance.',
+  form: {
+    name: 'Nom de la Connaissance',
+    namePlaceholder: 'Veuillez entrer le nom de la Connaissance',
+    nameError: 'Le nom ne peut pas être vide',
+    desc: 'Description des connaissances',
+    descInfo: 'Veuillez rédiger une description textuelle claire pour décrire le contenu de la Connaissance. Cette description sera utilisée comme base pour la correspondance lors de la sélection parmi plusieurs Connaissances pour l\'inférence.',
+    descPlaceholder: 'Décrivez ce qui se trouve dans cette Connaissance. Une description détaillée permet à l\'IA d\'accéder au contenu de la Connaissance en temps opportun. Si vide, Dify utilisera la stratégie de hit par défaut.',
+    descWrite: 'Apprenez comment rédiger une bonne description de connaissance.',
+    permissions: 'Autorisations',
+    permissionsOnlyMe: 'Seulement moi',
+    permissionsAllMember: 'Tous les membres de l\'équipe',
+    indexMethod: 'Méthode d\'Indexation',
+    indexMethodHighQuality: 'Haute Qualité',
+    indexMethodHighQualityTip: 'Appelez l\'interface d\'embedding d\'OpenAI pour le traitement afin de fournir une précision plus élevée lorsque les utilisateurs font une requête.',
+    indexMethodEconomy: 'Économique',
+    indexMethodEconomyTip: 'Utilisez des moteurs vectoriels hors ligne, des index de mots-clés, etc. 
pour réduire la précision sans dépenser de jetons', + embeddingModel: 'Modèle d\'Embedding', + embeddingModelTip: 'Changez le modèle intégré, veuillez aller à', + embeddingModelTipLink: 'Paramètres', + retrievalSetting: { + title: 'Paramètre de récupération', + learnMore: 'En savoir plus', + description: 'à propos de la méthode de récupération.', + longDescription: 'À propos de la méthode de récupération, vous pouvez la modifier à tout moment dans les paramètres de Connaissance.', + }, + save: 'Enregistrer', + }, +} + +export default translation diff --git a/web/i18n/fr-FR/dataset.ts b/web/i18n/fr-FR/dataset.ts new file mode 100644 index 0000000000..2d649d914b --- /dev/null +++ b/web/i18n/fr-FR/dataset.ts @@ -0,0 +1,47 @@ +const translation = { + knowledge: 'Connaissance', + documentCount: ' documents', + wordCount: 'k mots', + appCount: ' applications liées', + createDataset: 'Créer des Connaissances', + createDatasetIntro: 'Importez vos propres données textuelles ou écrivez des données en temps réel via Webhook pour l\'amélioration du contexte LLM.', + deleteDatasetConfirmTitle: 'Supprimer cette Connaissance ?', + deleteDatasetConfirmContent: + 'La suppression de la Connaissance est irréversible. Les utilisateurs ne pourront plus accéder à votre Savoir, et toutes les configurations de prompt et les journaux seront supprimés de façon permanente.', + datasetDeleted: 'Connaissance supprimée', + datasetDeleteFailed: 'Échec de la suppression de la Connaissance', + didYouKnow: 'Saviez-vous ?', + intro1: 'La Connaissance peut être intégrée dans l\'application Dify', + intro2: 'comme un contexte', + intro3: ',', + intro4: 'ou ça ', + intro5: 'peut être créé', + intro6: 'comme un plug-in d\'index ChatGPT autonome à publier', + unavailable: 'Indisponible', + unavailableTip: 'Le modèle d\'embedding n\'est pas disponible, le modèle d\'embedding par défaut doit être configuré', + datasets: 'CONNAISSANCE', + datasetsApi: 'API', + retrieval: { + semantic_search: { + title: 'Recherche Vectorielle', + description: 'Générez des embeddings de requête et recherchez le morceau de texte le plus similaire à sa représentation vectorielle.', + }, + full_text_search: { + title: 'Recherche en Texte Intégral', + description: 'Indexez tous les termes dans le document, permettant aux utilisateurs de rechercher n\'importe quel terme et de récupérer le fragment de texte pertinent contenant ces termes.', + }, + hybrid_search: { + title: 'Recherche Hybride', + description: 'Exécutez une recherche en texte intégral et des recherches vectorielles en même temps, réorganisez pour sélectionner la meilleure correspondance pour la requête de l\'utilisateur. La configuration de l\'API du modèle de réorganisation est nécessaire.', + recommend: 'Recommander', + }, + invertedIndex: { + title: 'Index inversé', + description: 'L\'Index inversé est une structure utilisée pour une récupération efficace. 
Organisé par termes, chaque terme pointe vers des documents ou des pages web le contenant.', + }, + change: 'Changer', + changeRetrievalMethod: 'Changer la méthode de récupération', + }, +} + +export default translation diff --git a/web/i18n/fr-FR/explore.ts b/web/i18n/fr-FR/explore.ts new file mode 100644 index 0000000000..77d17b3f03 --- /dev/null +++ b/web/i18n/fr-FR/explore.ts @@ -0,0 +1,41 @@ +const translation = { + title: 'Explorer', + sidebar: { + discovery: 'Découverte', + chat: 'Discussion', + workspace: 'Espace de travail', + action: { + pin: 'Épingle', + unpin: 'Détacher', + rename: 'Renommer', + delete: 'Supprimer', + }, + delete: { + title: 'Supprimer l\'application', + content: 'Êtes-vous sûr de vouloir supprimer cette application ?', + }, + }, + apps: { + title: 'Explorez les applications par Dify', + description: 'Utilisez ces applications modèles instantanément ou personnalisez vos propres applications basées sur les modèles.', + allCategories: 'Toutes les catégories', + }, + appCard: { + addToWorkspace: 'Ajouter à l\'espace de travail', + customize: 'Personnaliser', + }, + appCustomize: { + title: 'Créer une application à partir de {{name}}', + subTitle: 'Icône de l\'application & nom', + nameRequired: 'Le nom de l\'application est requis', + }, + category: { + Assistant: 'Assistant', + Writing: 'Écriture', + Translate: 'Traduire', + Programming: 'Programmation', + HR: 'RH', + }, +} + +export default translation diff --git a/web/i18n/fr-FR/layout.ts b/web/i18n/fr-FR/layout.ts new file mode 100644 index 0000000000..928649474b --- /dev/null +++ b/web/i18n/fr-FR/layout.ts @@ -0,0 +1,4 @@ +const translation = { +} + +export default translation diff --git a/web/i18n/fr-FR/login.ts b/web/i18n/fr-FR/login.ts new file mode 100644 index 0000000000..38e4da2b18 --- /dev/null +++ b/web/i18n/fr-FR/login.ts @@ -0,0 +1,59 @@ +const translation = { + pageTitle: 'Salut, commençons !👋', + welcome: 'Bienvenue sur Dify, veuillez vous connecter pour continuer.', + email: 'Adresse e-mail', + emailPlaceholder: 'Votre email', + password: 'Mot de passe', + passwordPlaceholder: 'Votre mot de passe', + name: 'Nom d\'utilisateur', + namePlaceholder: 'Votre nom d\'utilisateur', + forget: 'Mot de passe oublié ?', + signBtn: 'Se connecter', + installBtn: 'Mettre en place', + setAdminAccount: 'Configuration d\'un compte administrateur', + setAdminAccountDesc: 'Privilèges maximum pour le compte administrateur, qui peut être utilisé pour créer des applications et gérer les fournisseurs de LLM, etc.', + createAndSignIn: 'Créer et se connecter', + oneMoreStep: 'Une étape de plus', + createSample: 'Sur la base de ces informations, nous créerons une application exemple pour vous', + invitationCode: 'Code d\'invitation', + invitationCodePlaceholder: 'Votre code d\'invitation', + interfaceLanguage: 'Langue de l\'interface', + timezone: 'Fuseau horaire', + go: 'Aller à Dify', + sendUsMail: 'Envoyez-nous votre introduction, et nous nous occuperons de la demande d\'invitation.', + acceptPP: 'J\'ai lu et j\'accepte la politique de confidentialité', + reset: 'Veuillez exécuter la commande suivante pour réinitialiser votre mot de passe', + withGitHub: 'Continuer avec GitHub', + withGoogle: 'Continuer avec Google', + rightTitle: 'Débloquez le plein potentiel des LLM', + rightDesc: 'Construisez sans effort des applications IA visuellement captivantes, opérationnelles et améliorables.', + tos: 'Conditions de Service', + pp: 'Politique de Confidentialité', + tosDesc: 'En vous inscrivant, vous acceptez nos', + donthave: 
'Vous n\'avez pas ?', + invalidInvitationCode: 'Code d\'invitation invalide', + accountAlreadyInited: 'Compte déjà initialisé', + error: { + emailEmpty: 'Une adresse e-mail est requise', + emailInValid: 'Veuillez entrer une adresse email valide', + nameEmpty: 'Le nom est requis', + passwordEmpty: 'Un mot de passe est requis', + passwordInvalid: 'Le mot de passe doit contenir des lettres et des chiffres, et la longueur doit être supérieure à 8.', + }, + license: { + tip: 'Avant de commencer Dify Community Edition, lisez le GitHub', + link: 'Licence Open-source', + }, + join: 'Rejoindre', + joinTipStart: 'Je vous invite à rejoindre', + joinTipEnd: 'équipe sur Dify', + invalid: 'Le lien a expiré', + explore: 'Explorez Dify', + activatedTipStart: 'Vous avez rejoint le', + activatedTipEnd: 'équipe', + activated: 'Connectez-vous maintenant', + adminInitPassword: 'Mot de passe d\'initialisation de l\'administrateur', + validate: 'Valider', +} + +export default translation diff --git a/web/i18n/fr-FR/register.ts b/web/i18n/fr-FR/register.ts new file mode 100644 index 0000000000..928649474b --- /dev/null +++ b/web/i18n/fr-FR/register.ts @@ -0,0 +1,4 @@ +const translation = { +} + +export default translation diff --git a/web/i18n/fr-FR/share-app.ts b/web/i18n/fr-FR/share-app.ts new file mode 100644 index 0000000000..97f0d992e0 --- /dev/null +++ b/web/i18n/fr-FR/share-app.ts @@ -0,0 +1,74 @@ +const translation = { + common: { + welcome: 'Bienvenue à l\'utilisation', + appUnavailable: 'L\'application n\'est pas disponible', + appUnkonwError: 'L\'application n\'est pas disponible', + }, + chat: { + newChat: 'Nouveau chat', + pinnedTitle: 'Épinglé', + unpinnedTitle: 'Discussions', + newChatDefaultName: 'Nouvelle conversation', + resetChat: 'Réinitialiser la conversation', + powerBy: 'Propulsé par', + prompt: 'Prompt', + privatePromptConfigTitle: 'Paramètres de conversation', + publicPromptConfigTitle: 'Prompt Initial', + configStatusDes: 'Avant de commencer, vous pouvez modifier les paramètres de conversation', + configDisabled: + 'Les paramètres de la session précédente ont été utilisés pour cette session.', + startChat: 'Commencer le Chat', + privacyPolicyLeft: + 'Veuillez lire', + privacyPolicyMiddle: + 'politique de confidentialité', + privacyPolicyRight: + 'fourni par le développeur de l\'application.', + deleteConversation: { + title: 'Supprimer la conversation', + content: 'Êtes-vous sûr de vouloir supprimer cette conversation ?', + }, + tryToSolve: 'Essayez de résoudre', + temporarySystemIssue: 'Désolé, problème temporaire du système.', + }, + generation: { + tabs: { + create: 'Exécuter une fois', + batch: 'Exécuter le lot', + saved: 'Enregistré', + }, + savedNoData: { + title: 'Vous n\'avez pas encore enregistré de résultat !', + description: 'Commencez à générer du contenu et retrouvez vos résultats sauvegardés ici.', + startCreateContent: 'Commencez à créer du contenu', + }, + title: 'Complétion IA', + queryTitle: 'Contenu de la requête', + completionResult: 'Résultat de la complétion', + queryPlaceholder: 'Rédigez le contenu de votre requête...', + run: 'Exécuter', + copy: 'Copier', + resultTitle: 'Complétion IA', + noData: 'L\'IA vous donnera ce que vous voulez ici.', + csvUploadTitle: 'Faites glisser et déposez votre fichier CSV ici, ou', + browse: 'parcourir', + csvStructureTitle: 'Le fichier CSV doit se conformer à la structure suivante :', + downloadTemplate: 'Téléchargez le modèle ici', + field: 'Champ', + batchFailed: { + info: '{{num}} exécutions échouées', + retry: 'Réessayer', + 
outputPlaceholder: 'Aucun contenu de sortie',
+    },
+    errorMsg: {
+      empty: 'Veuillez entrer le contenu dans le fichier téléchargé.',
+      fileStructNotMatch: 'Le fichier CSV téléchargé ne correspond pas à la structure.',
+      emptyLine: 'La ligne {{rowIndex}} est vide',
+      invalidLine: 'La ligne {{rowIndex}} : la valeur de {{varName}} ne peut pas être vide',
+      moreThanMaxLengthLine: 'La ligne {{rowIndex}} : la valeur de {{varName}} ne peut pas dépasser {{maxLength}} caractères',
+      atLeastOne: 'Veuillez entrer au moins une ligne dans le fichier téléchargé.',
+    },
+  },
+}
+
+export default translation
diff --git a/web/i18n/fr-FR/tools.ts b/web/i18n/fr-FR/tools.ts
new file mode 100644
index 0000000000..f0819a201a
--- /dev/null
+++ b/web/i18n/fr-FR/tools.ts
@@ -0,0 +1,115 @@
+const translation = {
+  title: 'Outils',
+  createCustomTool: 'Créer un Outil Personnalisé',
+  type: {
+    all: 'Tout',
+    builtIn: 'Intégré',
+    custom: 'Personnalisé',
+  },
+  contribute: {
+    line1: 'Je suis intéressé par',
+    line2: 'contribuer des outils à Dify.',
+    viewGuide: 'Voir le guide',
+  },
+  author: 'Par',
+  auth: {
+    unauthorized: 'Pour Autoriser',
+    authorized: 'Autorisé',
+    setup: 'Mettez en place l\'autorisation à utiliser',
+    setupModalTitle: 'Configurer l\'Autorisation',
+    setupModalTitleDescription: 'Après avoir configuré les identifiants, tous les membres de l\'espace de travail peuvent utiliser cet outil lors de l\'orchestration des applications.',
+  },
+  includeToolNum: '{{num}} outils inclus',
+  addTool: 'Ajouter un outil',
+  createTool: {
+    title: 'Créer un Outil Personnalisé',
+    editAction: 'Configurer',
+    editTitle: 'Modifier l\'Outil Personnalisé',
+    name: 'Nom',
+    toolNamePlaceHolder: 'Entrez le nom de l\'outil',
+    schema: 'Schéma',
+    schemaPlaceHolder: 'Entrez votre schéma OpenAPI ici',
+    viewSchemaSpec: 'Voir la spécification OpenAPI-Swagger',
+    importFromUrl: 'Importer depuis l\'URL',
+    importFromUrlPlaceHolder: 'https://...',
+    urlError: 'Veuillez entrer une URL valide',
+    examples: 'Exemples',
+    exampleOptions: {
+      json: 'Météo (JSON)',
+      yaml: 'Animalerie (YAML)',
+      blankTemplate: 'Modèle Vierge',
+    },
+    availableTools: {
+      title: 'Outils Disponibles',
+      name: 'Nom',
+      description: 'Description',
+      method: 'Méthode',
+      path: 'Chemin',
+      action: 'Actions',
+      test: 'Test',
+    },
+    authMethod: {
+      title: 'Méthode d\'autorisation',
+      type: 'Type d\'autorisation',
+      keyTooltip: 'Clé de l\'en-tête HTTP.
Vous pouvez la laisser telle quelle avec "Autorisation" si vous n\'avez aucune idée de ce que c\'est, ou la définir sur une valeur personnalisée.', + types: { + none: 'Aucun', + api_key: 'Clé API', + apiKeyPlaceholder: 'Nom de l\'en-tête HTTP pour la clé API', + apiValuePlaceholder: 'Entrez la clé API', + }, + key: 'Clé', + value: 'Valeur', + }, + authHeaderPrefix: { + title: 'Type d\'Authentification', + types: { + basic: 'Basique', + bearer: 'Porteur', + custom: 'Personnalisé', + }, + }, + privacyPolicy: 'Politique de confidentialité', + privacyPolicyPlaceholder: 'Veuillez entrer la politique de confidentialité', + }, + test: { + title: 'Test', + parametersValue: 'Paramètres & Valeur', + parameters: 'Paramètres', + value: 'Valeur', + testResult: 'Résultats du Test', + testResultPlaceholder: 'Le résultat du test s\'affichera ici', + }, + thought: { + using: 'Utilisation', + used: 'Utilisé', + requestTitle: 'Demande à', + responseTitle: 'Réponse de', + }, + setBuiltInTools: { + info: 'Infos', + setting: 'Paramètres', + toolDescription: 'Description de l\'outil', + parameters: 'paramètres', + string: 'chaîne', + number: 'nombre', + required: 'Requis', + infoAndSetting: 'Infos & Paramètres', + }, + noCustomTool: { + title: 'Pas d\'outils personnalisés !', + content: 'Ajoutez et gérez vos outils personnalisés ici pour construire des applications IA.', + createTool: 'Créer un outil', + }, + noSearchRes: { + title: 'Désolé, aucun résultat !', + content: 'Nous n\'avons trouvé aucun outil correspondant à votre recherche.', + reset: 'Réinitialiser la recherche', + }, + builtInPromptTitle: 'Invite', + toolRemoved: 'Outil supprimé', + notAuthorized: 'Outil non autorisé', + howToGet: 'Comment obtenir', +} + +export default translation diff --git a/web/i18n/language.ts b/web/i18n/language.ts index c72717dc84..f885289e07 100644 --- a/web/i18n/language.ts +++ b/web/i18n/language.ts @@ -48,7 +48,7 @@ export const languages = [ value: 'fr-FR', name: 'Français(France)', example: 'Bonjour, Dify!', - supported: false, + supported: true, }, { value: 'de-DE', From 8b15b742ade2713d9623db8ad7cd01e80881421c Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Wed, 13 Mar 2024 20:29:38 +0800 Subject: [PATCH 026/200] generalize position helper for parsing _position.yaml and sorting objects by name (#2803) --- api/core/extension/extensible.py | 14 ++-- .../model_providers/__base/ai_model.py | 14 +--- .../model_providers/model_provider_factory.py | 22 ++---- api/core/tools/provider/builtin/_positions.py | 21 +++--- api/core/utils/position_helper.py | 70 +++++++++++++++++++ 5 files changed, 95 insertions(+), 46 deletions(-) create mode 100644 api/core/utils/position_helper.py diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index c19aaefe9e..1809dcd8df 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -3,11 +3,12 @@ import importlib.util import json import logging import os -from collections import OrderedDict from typing import Any, Optional from pydantic import BaseModel +from core.utils.position_helper import sort_to_dict_by_position_map + class ExtensionModule(enum.Enum): MODERATION = 'moderation' @@ -36,7 +37,8 @@ class Extensible: @classmethod def scan_extensions(cls): - extensions = {} + extensions: list[ModuleExtension] = [] + position_map = {} # get the path of the current class current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py') @@ -63,6 +65,7 @@ class Extensible: if os.path.exists(builtin_file_path): with 
open(builtin_file_path, encoding='utf-8') as f: position = int(f.read().strip()) + position_map[extension_name] = position if (extension_name + '.py') not in file_names: logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") @@ -96,16 +99,15 @@ class Extensible: with open(json_path, encoding='utf-8') as f: json_data = json.load(f) - extensions[extension_name] = ModuleExtension( + extensions.append(ModuleExtension( extension_class=extension_class, name=extension_name, label=json_data.get('label'), form_schema=json_data.get('form_schema'), builtin=builtin, position=position - ) + )) - sorted_items = sorted(extensions.items(), key=lambda x: (x[1].position is None, x[1].position)) - sorted_extensions = OrderedDict(sorted_items) + sorted_extensions = sort_to_dict_by_position_map(position_map, extensions, lambda x: x.name) return sorted_extensions diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 026e6eca21..34a7375493 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -18,6 +18,7 @@ from core.model_runtime.entities.model_entities import ( ) from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer +from core.utils.position_helper import get_position_map, sort_by_position_map class AIModel(ABC): @@ -148,15 +149,7 @@ class AIModel(ABC): ] # get _position.yaml file path - position_file_path = os.path.join(provider_model_type_path, '_position.yaml') - - # read _position.yaml file - position_map = {} - if os.path.exists(position_file_path): - with open(position_file_path, encoding='utf-8') as f: - positions = yaml.safe_load(f) - # convert list to dict with key as model provider name, value as index - position_map = {position: index for index, position in enumerate(positions)} + position_map = get_position_map(provider_model_type_path) # traverse all model_schema_yaml_paths for model_schema_yaml_path in model_schema_yaml_paths: @@ -206,8 +199,7 @@ class AIModel(ABC): model_schemas.append(model_schema) # resort model schemas by position - if position_map: - model_schemas.sort(key=lambda x: position_map.get(x.model, 999)) + model_schemas = sort_by_position_map(position_map, model_schemas, lambda x: x.model) # cache model schemas self.model_schemas = model_schemas diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index 185ff62711..ee0385c6d0 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -1,10 +1,8 @@ import importlib import logging import os -from collections import OrderedDict from typing import Optional -import yaml from pydantic import BaseModel from core.model_runtime.entities.model_entities import ModelType @@ -12,6 +10,7 @@ from core.model_runtime.entities.provider_entities import ProviderConfig, Provid from core.model_runtime.model_providers.__base.model_provider import ModelProvider from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator +from core.utils.position_helper import get_position_map, 
sort_to_dict_by_position_map logger = logging.getLogger(__name__) @@ -200,7 +199,6 @@ class ModelProviderFactory: if self.model_provider_extensions: return self.model_provider_extensions - model_providers = {} # get the path of current classes current_path = os.path.abspath(__file__) @@ -215,17 +213,10 @@ class ModelProviderFactory: ] # get _position.yaml file path - position_file_path = os.path.join(model_providers_path, '_position.yaml') - - # read _position.yaml file - position_map = {} - if os.path.exists(position_file_path): - with open(position_file_path, encoding='utf-8') as f: - positions = yaml.safe_load(f) - # convert list to dict with key as model provider name, value as index - position_map = {position: index for index, position in enumerate(positions)} + position_map = get_position_map(model_providers_path) # traverse all model_provider_dir_paths + model_providers: list[ModelProviderExtension] = [] for model_provider_dir_path in model_provider_dir_paths: # get model_provider dir name model_provider_name = os.path.basename(model_provider_dir_path) @@ -256,14 +247,13 @@ class ModelProviderFactory: logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.") continue - model_providers[model_provider_name] = ModelProviderExtension( + model_providers.append(ModelProviderExtension( name=model_provider_name, provider_instance=model_provider_class(), position=position_map.get(model_provider_name) - ) + )) - sorted_items = sorted(model_providers.items(), key=lambda x: (x[1].position is None, x[1].position)) - sorted_extensions = OrderedDict(sorted_items) + sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name) self.model_provider_extensions = sorted_extensions diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py index fa2c5d27ef..2bf70bd356 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -1,8 +1,7 @@ import os.path -from yaml import FullLoader, load - from core.tools.entities.user_entities import UserToolProvider +from core.utils.position_helper import get_position_map, sort_by_position_map class BuiltinToolProviderSort: @@ -11,18 +10,14 @@ class BuiltinToolProviderSort: @classmethod def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: if not cls._position: - tmp_position = {} - file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml') - with open(file_path) as f: - for pos, val in enumerate(load(f, Loader=FullLoader)): - tmp_position[val] = pos - cls._position = tmp_position + cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..')) - def sort_compare(provider: UserToolProvider) -> int: + def name_func(provider: UserToolProvider) -> str: if provider.type == UserToolProvider.ProviderType.MODEL: - return cls._position.get(f'model.{provider.name}', 10000) - return cls._position.get(provider.name, 10000) - - sorted_providers = sorted(providers, key=sort_compare) + return f'model.{provider.name}' + else: + return provider.name + + sorted_providers = sort_by_position_map(cls._position, providers, name_func) return sorted_providers \ No newline at end of file diff --git a/api/core/utils/position_helper.py b/api/core/utils/position_helper.py new file mode 100644 index 0000000000..e038390e09 --- /dev/null +++ b/api/core/utils/position_helper.py @@ -0,0 +1,70 @@ +import logging +import os +from collections import OrderedDict +from 
collections.abc import Callable
+from typing import Any, AnyStr
+
+import yaml
+
+
+def get_position_map(
+        folder_path: AnyStr,
+        file_name: str = '_position.yaml',
+) -> dict[str, int]:
+    """
+    Get the mapping from name to index from a YAML file
+    :param folder_path: the folder containing the position file
+    :param file_name: the YAML file name, default to '_position.yaml'
+    :return: a dict with name as key and index as value
+    """
+    try:
+        position_file_name = os.path.join(folder_path, file_name)
+        if not os.path.exists(position_file_name):
+            return {}
+
+        with open(position_file_name, encoding='utf-8') as f:
+            positions = yaml.safe_load(f)
+        position_map = {}
+        for index, name in enumerate(positions):
+            if name and isinstance(name, str):
+                position_map[name.strip()] = index
+        return position_map
+    except Exception:
+        logging.warning(f'Failed to load the YAML position file {folder_path}/{file_name}.')
+        return {}
+
+
+def sort_by_position_map(
+        position_map: dict[str, int],
+        data: list[Any],
+        name_func: Callable[[Any], str],
+) -> list[Any]:
+    """
+    Sort the objects by the position map.
+    If the name of the object is not in the position map, it will be put at the end.
+    :param position_map: the map holding positions in the form of {name: index}
+    :param name_func: the function to get the name of the object
+    :param data: the data to be sorted
+    :return: the sorted objects
+    """
+    if not position_map or not data:
+        return data
+
+    return sorted(data, key=lambda x: position_map.get(name_func(x), float('inf')))
+
+
+def sort_to_dict_by_position_map(
+        position_map: dict[str, int],
+        data: list[Any],
+        name_func: Callable[[Any], str],
+) -> OrderedDict[str, Any]:
+    """
+    Sort the objects into an ordered dict by the position map.
+    If the name of the object is not in the position map, it will be put at the end.
+    :param position_map: the map holding positions in the form of {name: index}
+    :param name_func: the function to get the name of the object
+    :param data: the data to be sorted
+    :return: an OrderedDict with the sorted pairs of name and object
+    """
+    sorted_items = sort_by_position_map(position_map, data, name_func)
+    return OrderedDict([(name_func(item), item) for item in sorted_items])

From 1f92b55f5817169cc8bcdc3d53e6c935312b166c Mon Sep 17 00:00:00 2001
From: crazywoola <100913391+crazywoola@users.noreply.github.com>
Date: Wed, 13 Mar 2024 22:25:18 +0800
Subject: [PATCH 027/200] fix: doc for completion-messages (#2820)

---
 web/app/components/develop/template/template.en.mdx | 12 +++++++-----
 web/app/components/develop/template/template.zh.mdx | 12 +++++++-----
 2 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/web/app/components/develop/template/template.en.mdx b/web/app/components/develop/template/template.en.mdx
index f930cfe1c9..d975716913 100644
--- a/web/app/components/develop/template/template.en.mdx
+++ b/web/app/components/develop/template/template.en.mdx
@@ -42,13 +42,13 @@ The text generation application offers non-session support and is ideal for tran
     ### Request Body
-
-      User Input/Question content
-
+      Allows the entry of various variable values defined by the App. The `inputs` parameter contains multiple key/value pairs, with each key corresponding to a specific variable and each value being the specific value for that variable. The text generation application requires at least one key/value pair to be provided.
+      - `query` (string) Required
+        The input text, the content to be processed.
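As a quick cross-check of the corrected request shape (`query` nested inside `inputs`), here is a minimal Python sketch. It is illustrative only: the base URL and API key are placeholder assumptions, and only the endpoint path and body fields are taken from the docs above.

```python
# Illustrative sketch: API_BASE and API_KEY are placeholders, not values from this patch.
import requests

API_BASE = "https://api.dify.example/v1"  # stands in for the app's api_base_url
API_KEY = "your-api-key"                  # placeholder credential

response = requests.post(
    f"{API_BASE}/completion-messages",
    headers={
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    },
    json={
        "inputs": {"query": "Hello, world!"},  # `query` lives inside `inputs`
        "response_mode": "blocking",
        "user": "abc-123",
    },
    timeout=60,
)
print(response.json())
```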
The mode of response return, supporting: @@ -138,14 +138,16 @@ The text generation application offers non-session support and is ideal for tran - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/completion-messages' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ - "inputs": {}, + "inputs": { + "query": "Hello, world!" + }, "response_mode": "streaming", "user": "abc-123" }' diff --git a/web/app/components/develop/template/template.zh.mdx b/web/app/components/develop/template/template.zh.mdx index 8153906d0a..1b83c6a1c3 100644 --- a/web/app/components/develop/template/template.zh.mdx +++ b/web/app/components/develop/template/template.zh.mdx @@ -41,13 +41,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' ### Request Body - - 用户输入/提问内容。 - (选填)允许传入 App 定义的各变量值。 inputs 参数包含了多组键值对(Key/Value pairs),每组的键对应一个特定变量,每组的值则是该变量的具体值。 文本生成型应用要求至少传入一组键值对。 + - `query` (string) 必填 + 用户输入的文本内容。 - `streaming` 流式模式(推荐)。基于 SSE(**[Server-Sent Events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events)**)实现类似打字机输出方式的流式返回。 @@ -140,18 +139,21 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/completion-messages' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ - "inputs": {}, + "inputs": { + "query": "Hello, world!" + }, "response_mode": "streaming", "user": "abc-123" }' ``` + ``` ### blocking From e52448b84b7452408dbfd03a9b07dd2cb1b8f63d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=86=E8=90=8C=E9=97=B7=E6=B2=B9=E7=93=B6?= <253605712@qq.com> Date: Thu, 14 Mar 2024 09:14:27 +0800 Subject: [PATCH 028/200] feat:add api-version selection for azure openai APIs (#2821) --- .../model_providers/azure_openai/_common.py | 3 ++- .../model_providers/azure_openai/_constant.py | 1 - .../azure_openai/azure_openai.yaml | 16 ++++++++++++++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/api/core/model_runtime/model_providers/azure_openai/_common.py b/api/core/model_runtime/model_providers/azure_openai/_common.py index b65138252b..31c788d226 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_common.py +++ b/api/core/model_runtime/model_providers/azure_openai/_common.py @@ -15,10 +15,11 @@ from core.model_runtime.model_providers.azure_openai._constant import AZURE_OPEN class _CommonAzureOpenAI: @staticmethod def _to_credential_kwargs(credentials: dict) -> dict: + api_version = credentials.get('openai_api_version', AZURE_OPENAI_API_VERSION) credentials_kwargs = { "api_key": credentials['openai_api_key'], "azure_endpoint": credentials['openai_api_base'], - "api_version": AZURE_OPENAI_API_VERSION, + "api_version": api_version, "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "max_retries": 1, } diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py index dda0258812..4aa767fa1d 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py @@ -16,7 +16,6 @@ from core.model_runtime.entities.model_entities import ( AZURE_OPENAI_API_VERSION = '2024-02-15-preview' - def _get_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule: rule = ParameterRule( name='max_tokens', diff --git 
a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml index 58800ddee2..224f2a08a1 100644 --- a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml +++ b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml @@ -46,6 +46,22 @@ model_credential_schema: placeholder: zh_Hans: 在此输入您的 API Key en_US: Enter your API key here + - variable: openai_api_version + label: + zh_Hans: API 版本 + en_US: API Version + type: select + required: true + options: + - label: + en_US: 2024-02-15-preview + value: 2024-02-15-preview + - label: + en_US: 2023-12-01-preview + value: 2023-12-01-preview + placeholder: + zh_Hans: 在此选择您的 API 版本 + en_US: Select your API Version here - variable: base_model_name label: en_US: Base Model From 4fe585acc23d1395c7e15010a282e9bef2c7d09c Mon Sep 17 00:00:00 2001 From: Eric Wang Date: Thu, 14 Mar 2024 10:08:24 +0800 Subject: [PATCH 029/200] feat(llm/models): add claude-3-haiku-20240307 (#2825) --- .gitignore | 3 +- .../llm/claude-3-haiku-20240307.yaml | 37 +++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 api/core/model_runtime/model_providers/anthropic/llm/claude-3-haiku-20240307.yaml diff --git a/.gitignore b/.gitignore index 658c9319b2..c957d63174 100644 --- a/.gitignore +++ b/.gitignore @@ -154,4 +154,5 @@ sdks/python-client/dist sdks/python-client/dify_client.egg-info .vscode/* -!.vscode/launch.json \ No newline at end of file +!.vscode/launch.json +pyrightconfig.json diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-haiku-20240307.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-haiku-20240307.yaml new file mode 100644 index 0000000000..073d0c3a7d --- /dev/null +++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-haiku-20240307.yaml @@ -0,0 +1,37 @@ +model: claude-3-haiku-20240307 +label: + en_US: claude-3-haiku-20240307 +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 200000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
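+    # optional sampling control; when unset, the provider-side default applies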
+ required: false + - name: max_tokens + use_template: max_tokens + required: true + default: 4096 + min: 1 + max: 4096 + - name: response_format + use_template: response_format +pricing: + input: '0.25' + output: '1.25' + unit: '0.000001' + currency: USD From 1e5455e266a17252f7105f11311d47ef99a6162d Mon Sep 17 00:00:00 2001 From: Qun <51054082+QunBB@users.noreply.github.com> Date: Thu, 14 Mar 2024 15:26:07 +0800 Subject: [PATCH 030/200] enhance: use override_settings for concurrent stable diffusion (#2818) --- .../stablediffusion/tools/stable_diffusion.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py index e449062718..4c022f983f 100644 --- a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py @@ -131,7 +131,8 @@ class StableDiffusionTool(BuiltinTool): negative_prompt=negative_prompt, width=width, height=height, - steps=steps) + steps=steps, + model=model) return self.text2img(base_url=base_url, lora=lora, @@ -139,7 +140,8 @@ class StableDiffusionTool(BuiltinTool): negative_prompt=negative_prompt, width=width, height=height, - steps=steps) + steps=steps, + model=model) def validate_models(self) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ @@ -197,7 +199,7 @@ class StableDiffusionTool(BuiltinTool): def img2img(self, base_url: str, lora: str, image_binary: bytes, prompt: str, negative_prompt: str, - width: int, height: int, steps: int) \ + width: int, height: int, steps: int, model: str) \ -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ generate image @@ -213,7 +215,8 @@ class StableDiffusionTool(BuiltinTool): "sampler_name": "Euler a", "restore_faces": False, "steps": steps, - "script_args": ["outpainting mk2"] + "script_args": ["outpainting mk2"], + "override_settings": {"sd_model_checkpoint": model} } if lora: @@ -236,7 +239,7 @@ class StableDiffusionTool(BuiltinTool): except Exception as e: return self.create_text_message('Failed to generate image') - def text2img(self, base_url: str, lora: str, prompt: str, negative_prompt: str, width: int, height: int, steps: int) \ + def text2img(self, base_url: str, lora: str, prompt: str, negative_prompt: str, width: int, height: int, steps: int, model: str) \ -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ generate image @@ -253,6 +256,7 @@ class StableDiffusionTool(BuiltinTool): draw_options['height'] = height draw_options['steps'] = steps draw_options['negative_prompt'] = negative_prompt + draw_options['override_settings']['sd_model_checkpoint'] = model try: url = str(URL(base_url) / 'sdapi' / 'v1' / 'txt2img') From 5cab2b711f6fce48da9d15e279175a452ed0d8b4 Mon Sep 17 00:00:00 2001 From: ChiayenGu <47971475+ChiayenGu@users.noreply.github.com> Date: Thu, 14 Mar 2024 16:41:40 +0800 Subject: [PATCH 031/200] fix: doc for datasets (#2831) --- .../(commonLayout)/datasets/template/template.en.mdx | 8 ++++---- .../(commonLayout)/datasets/template/template.zh.mdx | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/web/app/(commonLayout)/datasets/template/template.en.mdx b/web/app/(commonLayout)/datasets/template/template.en.mdx index e3f3f21e6b..786b3277d0 100644 --- a/web/app/(commonLayout)/datasets/template/template.en.mdx +++ b/web/app/(commonLayout)/datasets/template/template.en.mdx @@ -178,10 +178,10 @@ import { 
Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from title="Request" tag="POST" label="/datasets/{dataset_id}/document/create_by_file" - targetCode={`curl --location POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create_by_file' \\\n--header 'Authorization: Bearer {api_key}' \\\n--form 'data="{"indexing_technique":"high_quality","process_rule":{"rules":{"pre_processing_rules":[{"id":"remove_extra_spaces","enabled":true},{"id":"remove_urls_emails","enabled":true}],"segmentation":{"separator":"###","max_tokens":500}},"mode":"custom"}}";type=text/plain' \\\n--form 'file=@"/path/to/file"'`} + targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create_by_file' \\\n--header 'Authorization: Bearer {api_key}' \\\n--form 'data="{"indexing_technique":"high_quality","process_rule":{"rules":{"pre_processing_rules":[{"id":"remove_extra_spaces","enabled":true},{"id":"remove_urls_emails","enabled":true}],"segmentation":{"separator":"###","max_tokens":500}},"mode":"custom"}}";type=text/plain' \\\n--form 'file=@"/path/to/file"'`} > ```bash {{ title: 'cURL' }} - curl --location POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create_by_file' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create_by_file' \ --header 'Authorization: Bearer {api_key}' \ --form 'data="{\"name\":\"Dify\",\"indexing_technique\":\"high_quality\",\"process_rule\":{\"rules\":{\"pre_processing_rules\":[{\"id\":\"remove_extra_spaces\",\"enabled\":true},{\"id\":\"remove_urls_emails\",\"enabled\":true}],\"segmentation\":{\"separator\":\"###\",\"max_tokens\":500}},\"mode\":\"custom\"}}";type=text/plain' \ --form 'file=@"/path/to/file"' @@ -491,10 +491,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from title="Request" tag="POST" label="/datasets/{dataset_id}/documents/{document_id}/update_by_file" - targetCode={`curl --location POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/{document_id}/create_by_file' \\\n--header 'Authorization: Bearer {api_key}' \\\n--form 'data="{"name":"Dify","indexing_technique":"high_quality","process_rule":{"rules":{"pre_processing_rules":[{"id":"remove_extra_spaces","enabled":true},{"id":"remove_urls_emails","enabled":true}],"segmentation":{"separator":"###","max_tokens":500}},"mode":"custom"}}";type=text/plain' \\\n--form 'file=@"/path/to/file"'`} + targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/{document_id}/create_by_file' \\\n--header 'Authorization: Bearer {api_key}' \\\n--form 'data="{"name":"Dify","indexing_technique":"high_quality","process_rule":{"rules":{"pre_processing_rules":[{"id":"remove_extra_spaces","enabled":true},{"id":"remove_urls_emails","enabled":true}],"segmentation":{"separator":"###","max_tokens":500}},"mode":"custom"}}";type=text/plain' \\\n--form 'file=@"/path/to/file"'`} > ```bash {{ title: 'cURL' }} - curl --location POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/{document_id}/create_by_file' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/{document_id}/create_by_file' \ --header 'Authorization: Bearer {api_key}' \ --form 
'data="{\"name\":\"Dify\",\"indexing_technique\":\"high_quality\",\"process_rule\":{\"rules\":{\"pre_processing_rules\":[{\"id\":\"remove_extra_spaces\",\"enabled\":true},{\"id\":\"remove_urls_emails\",\"enabled\":true}],\"segmentation\":{\"separator\":\"###\",\"max_tokens\":500}},\"mode\":\"custom\"}}";type=text/plain' \ --form 'file=@"/path/to/file"' diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx index c176041925..f0bf12fac5 100644 --- a/web/app/(commonLayout)/datasets/template/template.zh.mdx +++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx @@ -76,7 +76,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create_by_text' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"name": "text","text": "text","indexing_technique": "high_quality","process_rule": {"mode": "automatic"}}'`} > ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create_by_text' \ + curl --location --request --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create_by_text' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ @@ -178,10 +178,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from title="Request" tag="POST" label="/datasets/{dataset_id}/document/create_by_file" - targetCode={`curl --location POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create_by_file' \\\n--header 'Authorization: Bearer {api_key}' \\\n--form 'data="{"indexing_technique":"high_quality","process_rule":{"rules":{"pre_processing_rules":[{"id":"remove_extra_spaces","enabled":true},{"id":"remove_urls_emails","enabled":true}],"segmentation":{"separator":"###","max_tokens":500}},"mode":"custom"}}";type=text/plain' \\\n--form 'file=@"/path/to/file"'`} + targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create_by_file' \\\n--header 'Authorization: Bearer {api_key}' \\\n--form 'data="{"indexing_technique":"high_quality","process_rule":{"rules":{"pre_processing_rules":[{"id":"remove_extra_spaces","enabled":true},{"id":"remove_urls_emails","enabled":true}],"segmentation":{"separator":"###","max_tokens":500}},"mode":"custom"}}";type=text/plain' \\\n--form 'file=@"/path/to/file"'`} > ```bash {{ title: 'cURL' }} - curl --location POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create_by_file' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create_by_file' \ --header 'Authorization: Bearer {api_key}' \ --form 'data="{\"name\":\"Dify\",\"indexing_technique\":\"high_quality\",\"process_rule\":{\"rules\":{\"pre_processing_rules\":[{\"id\":\"remove_extra_spaces\",\"enabled\":true},{\"id\":\"remove_urls_emails\",\"enabled\":true}],\"segmentation\":{\"separator\":\"###\",\"max_tokens\":500}},\"mode\":\"custom\"}}";type=text/plain' \ --form 'file=@"/path/to/file"' @@ -491,10 +491,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from title="Request" tag="POST" label="/datasets/{dataset_id}/documents/{document_id}/update_by_file" - targetCode={`curl --location POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/{document_id}/create_by_file' \\\n--header 'Authorization: Bearer {api_key}' 
\\\n--form 'data="{"name":"Dify","indexing_technique":"high_quality","process_rule":{"rules":{"pre_processing_rules":[{"id":"remove_extra_spaces","enabled":true},{"id":"remove_urls_emails","enabled":true}],"segmentation":{"separator":"###","max_tokens":500}},"mode":"custom"}}";type=text/plain' \\\n--form 'file=@"/path/to/file"'`} + targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/{document_id}/create_by_file' \\\n--header 'Authorization: Bearer {api_key}' \\\n--form 'data="{"name":"Dify","indexing_technique":"high_quality","process_rule":{"rules":{"pre_processing_rules":[{"id":"remove_extra_spaces","enabled":true},{"id":"remove_urls_emails","enabled":true}],"segmentation":{"separator":"###","max_tokens":500}},"mode":"custom"}}";type=text/plain' \\\n--form 'file=@"/path/to/file"'`} > ```bash {{ title: 'cURL' }} - curl --location POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/{document_id}/create_by_file' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/{document_id}/create_by_file' \ --header 'Authorization: Bearer {api_key}' \ --form 'data="{\"name\":\"Dify\",\"indexing_technique\":\"high_quality\",\"process_rule\":{\"rules\":{\"pre_processing_rules\":[{\"id\":\"remove_extra_spaces\",\"enabled\":true},{\"id\":\"remove_urls_emails\",\"enabled\":true}],\"segmentation\":{\"separator\":\"###\",\"max_tokens\":500}},\"mode\":\"custom\"}}";type=text/plain' \ --form 'file=@"/path/to/file"' From 19d3a56194fdd8030bf7c08b11db5a8f50ac98b2 Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Thu, 14 Mar 2024 17:01:48 +0800 Subject: [PATCH 032/200] feat: add weekday calculator in time tool (#2822) --- .../provider/builtin/time/tools/weekday.py | 42 +++++++++++++++++++ .../provider/builtin/time/tools/weekday.yaml | 42 +++++++++++++++++++ 2 files changed, 84 insertions(+) create mode 100644 api/core/tools/provider/builtin/time/tools/weekday.py create mode 100644 api/core/tools/provider/builtin/time/tools/weekday.yaml diff --git a/api/core/tools/provider/builtin/time/tools/weekday.py b/api/core/tools/provider/builtin/time/tools/weekday.py new file mode 100644 index 0000000000..4461cb5a32 --- /dev/null +++ b/api/core/tools/provider/builtin/time/tools/weekday.py @@ -0,0 +1,42 @@ +import calendar +from datetime import datetime +from typing import Any, Union + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class WeekdayTool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + Calculate the day of the week for a given date + """ + year = tool_parameters.get('year') + month = tool_parameters.get('month') + day = tool_parameters.get('day') + + date_obj = self.convert_datetime(year, month, day) + if not date_obj: + return self.create_text_message(f'Invalid date: Year {year}, Month {month}, Day {day}.') + + weekday_name = calendar.day_name[date_obj.weekday()] + month_name = calendar.month_name[month] + readable_date = f"{month_name} {date_obj.day}, {date_obj.year}" + return self.create_text_message(f'{readable_date} is {weekday_name}.') + + @staticmethod + def convert_datetime(year, month, day) -> datetime | None: + try: + # allowed range in datetime module + if not (year >= 1 and 1 <= month <= 12 and 1 <= day <= 31): + return None + + year = int(year) + month = int(month) + day = int(day) + return datetime(year, month, day) + except ValueError: + return 
None diff --git a/api/core/tools/provider/builtin/time/tools/weekday.yaml b/api/core/tools/provider/builtin/time/tools/weekday.yaml new file mode 100644 index 0000000000..481585e8c9 --- /dev/null +++ b/api/core/tools/provider/builtin/time/tools/weekday.yaml @@ -0,0 +1,42 @@ +identity: + name: weekday + author: Bowen Liang + label: + en_US: Weekday Calculator + zh_Hans: 星期几计算器 +description: + human: + en_US: A tool for calculating the weekday of a given date. + zh_Hans: 计算指定日期为星期几的工具。 + llm: A tool for calculating the weekday of a given date by year, month and day. +parameters: + - name: year + type: number + required: true + form: llm + label: + en_US: Year + zh_Hans: 年 + human_description: + en_US: Year + zh_Hans: 年 + - name: month + type: number + required: true + form: llm + label: + en_US: Month + zh_Hans: 月 + human_description: + en_US: Month + zh_Hans: 月 + - name: day + type: number + required: true + form: llm + label: + en_US: day + zh_Hans: 日 + human_description: + en_US: day + zh_Hans: 日 From bbea3a6b8447b87f0fffef0d5c1387285feb128a Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Thu, 14 Mar 2024 17:51:11 +0800 Subject: [PATCH 033/200] fix: compatible to old tool config (#2837) --- api/controllers/console/app/app.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index fb697b6418..ff97405415 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -245,6 +245,8 @@ class AppApi(Resource): agent_mode = model_config.agent_mode_dict # decrypt agent tool parameters if it's secret-input for tool in agent_mode.get('tools') or []: + if not isinstance(tool, dict) or len(tool.keys()) <= 3: + continue agent_tool_entity = AgentToolEntity(**tool) # get tool try: From 4d637701894ee5cc997a9ed5cad82c3835eabc30 Mon Sep 17 00:00:00 2001 From: David Date: Thu, 14 Mar 2024 09:53:55 +0000 Subject: [PATCH 034/200] fix: The generate conversation name was not saved (#2836) --- .../generate_conversation_name_when_first_message_created.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py index 12cb325e45..74dc8d5112 100644 --- a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py +++ b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py @@ -26,5 +26,6 @@ def handle(sender, **kwargs): conversation.name = name except: pass - + + db.session.merge(conversation) db.session.commit() From af98954fc19b3f6eea4ebe19e8a0f99cdc38c2b0 Mon Sep 17 00:00:00 2001 From: crazywoola <100913391+crazywoola@users.noreply.github.com> Date: Thu, 14 Mar 2024 18:03:59 +0800 Subject: [PATCH 035/200] Feat/add script to check i18n keys (#2835) --- web/i18n/script.js | 77 ++++++++++++++++++++++++++++++++++++++++++++++ web/package.json | 3 +- 2 files changed, 79 insertions(+), 1 deletion(-) create mode 100644 web/i18n/script.js diff --git a/web/i18n/script.js b/web/i18n/script.js new file mode 100644 index 0000000000..e6f1853c83 --- /dev/null +++ b/web/i18n/script.js @@ -0,0 +1,77 @@ +/* eslint-disable no-eval */ +const fs = require('node:fs') +const path = require('node:path') +const transpile = require('typescript').transpile + +const targetLanguage = 'en-US' +const languages = ['zh-Hans', 'fr-FR', 'ja-JP', 'pt-BR', 'uk-UA', 'vi-VN'] + +async 
function getKeysFromLanuage(language) { + return new Promise((resolve, reject) => { + const folderPath = path.join(__dirname, language) + let allKeys = [] + fs.readdir(folderPath, (err, files) => { + if (err) { + console.error('Error reading folder:', err) + reject(err) + return + } + + files.forEach((file) => { + const filePath = path.join(folderPath, file) + const fileName = file.replace(/\.[^/.]+$/, '') // Remove file extension + const camelCaseFileName = fileName.replace(/[-_](.)/g, (_, c) => + c.toUpperCase(), + ) // Convert to camel case + // console.log(camelCaseFileName) + const content = fs.readFileSync(filePath, 'utf8') + const translation = eval(transpile(content)) + const keys = Object.keys(translation) + const nestedKeys = [] + const iterateKeys = (obj, prefix = '') => { + for (const key in obj) { + const nestedKey = prefix ? `${prefix}.${key}` : key + nestedKeys.push(nestedKey) + if (typeof obj[key] === 'object') + iterateKeys(obj[key], nestedKey) + } + } + iterateKeys(translation) + + allKeys = [...keys, ...nestedKeys].map( + key => `${camelCaseFileName}.${key}`, + ) + }) + resolve(allKeys) + }) + }) +} + +async function main() { + const compareKeysCount = async () => { + const targetKeys = await getKeysFromLanuage(targetLanguage) + const languagesKeys = await Promise.all(languages.map(language => getKeysFromLanuage(language))) + + const keysCount = languagesKeys.map(keys => keys.length) + const targetKeysCount = targetKeys.length + + const comparison = languages.reduce((result, language, index) => { + const languageKeysCount = keysCount[index] + const difference = targetKeysCount - languageKeysCount + result[language] = difference + return result + }, {}) + + console.log(comparison) + + // Print missing keys + languages.forEach((language, index) => { + const missingKeys = targetKeys.filter(key => !languagesKeys[index].includes(key)) + console.log(`Missing keys in ${language}:`, missingKeys) + }) + } + + compareKeysCount() +} + +main() diff --git a/web/package.json b/web/package.json index 1ef344645a..fc466f42b3 100644 --- a/web/package.json +++ b/web/package.json @@ -11,7 +11,8 @@ "eslint-fix": "eslint --fix", "prepare": "cd ../ && node -e \"if (process.env.NODE_ENV !== 'production'){process.exit(1)} \" || husky install ./web/.husky", "gen-icons": "node ./app/components/base/icons/script.js", - "uglify-embed": "node ./bin/uglify-embed" + "uglify-embed": "node ./bin/uglify-embed", + "check-i18n": "node ./i18n/script.js" }, "dependencies": { "@babel/runtime": "^7.22.3", From 742be06ea9d8853534781460d91e79e3169ae0a6 Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Fri, 15 Mar 2024 11:41:51 +0800 Subject: [PATCH 036/200] Fix/localai (#2840) --- .../model_providers/localai/llm/llm.py | 14 +++++++--- .../model_providers/localai/localai.yaml | 9 ++++++ .../localai/text_embedding/text_embedding.py | 28 +++++++++++++++++-- 3 files changed, 44 insertions(+), 7 deletions(-) diff --git a/api/core/model_runtime/model_providers/localai/llm/llm.py b/api/core/model_runtime/model_providers/localai/llm/llm.py index 3c805682f3..161e65302f 100644 --- a/api/core/model_runtime/model_providers/localai/llm/llm.py +++ b/api/core/model_runtime/model_providers/localai/llm/llm.py @@ -1,6 +1,5 @@ from collections.abc import Generator from typing import cast -from urllib.parse import urljoin from httpx import Timeout from openai import ( @@ -19,6 +18,7 @@ from openai import ( from openai.types.chat import ChatCompletion, ChatCompletionChunk from 
openai.types.chat.chat_completion_message import FunctionCall from openai.types.completion import Completion +from yarl import URL from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta @@ -181,7 +181,7 @@ class LocalAILarguageModel(LargeLanguageModel): UserPromptMessage(content='ping') ], model_parameters={ 'max_tokens': 10, - }, stop=[]) + }, stop=[], stream=False) except Exception as ex: raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}') @@ -227,6 +227,12 @@ class LocalAILarguageModel(LargeLanguageModel): ) ] + model_properties = { + ModelPropertyKey.MODE: completion_model, + } if completion_model else {} + + model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get('context_size', '2048')) + entity = AIModelEntity( model=model, label=I18nObject( @@ -234,7 +240,7 @@ class LocalAILarguageModel(LargeLanguageModel): ), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, - model_properties={ ModelPropertyKey.MODE: completion_model } if completion_model else {}, + model_properties=model_properties, parameter_rules=rules ) @@ -319,7 +325,7 @@ class LocalAILarguageModel(LargeLanguageModel): client_kwargs = { "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "api_key": "1", - "base_url": urljoin(credentials['server_url'], 'v1'), + "base_url": str(URL(credentials['server_url']) / 'v1'), } return client_kwargs diff --git a/api/core/model_runtime/model_providers/localai/localai.yaml b/api/core/model_runtime/model_providers/localai/localai.yaml index e4b625d171..a870914632 100644 --- a/api/core/model_runtime/model_providers/localai/localai.yaml +++ b/api/core/model_runtime/model_providers/localai/localai.yaml @@ -56,3 +56,12 @@ model_credential_schema: placeholder: zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080 en_US: Enter the url of your LocalAI, e.g. 
http://192.168.1.100:8080 + - variable: context_size + label: + zh_Hans: 上下文大小 + en_US: Context size + placeholder: + zh_Hans: 输入上下文大小 + en_US: Enter context size + required: false + type: text-input diff --git a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py index c95007d271..954c9d10f2 100644 --- a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py @@ -1,11 +1,12 @@ import time from json import JSONDecodeError, dumps -from os.path import join from typing import Optional from requests import post +from yarl import URL -from core.model_runtime.entities.model_entities import PriceType +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( InvokeAuthorizationError, @@ -57,7 +58,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): } try: - response = post(join(url, 'embeddings'), headers=headers, data=dumps(data), timeout=10) + response = post(str(URL(url) / 'embeddings'), headers=headers, data=dumps(data), timeout=10) except Exception as e: raise InvokeConnectionError(str(e)) @@ -113,6 +114,27 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): # use GPT2Tokenizer to get num tokens num_tokens += self._get_num_tokens_by_gpt2(text) return num_tokens + + def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + Get customizable model schema + + :param model: model name + :param credentials: model credentials + :return: model schema + """ + return AIModelEntity( + model=model, + label=I18nObject(zh_Hans=model, en_US=model), + model_type=ModelType.TEXT_EMBEDDING, + features=[], + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', '512')), + ModelPropertyKey.MAX_CHUNKS: 1, + }, + parameter_rules=[] + ) def validate_credentials(self, model: str, credentials: dict) -> None: """ From f29280ba5c4be690d34a797cd3a0f552fffcad4f Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Fri, 15 Mar 2024 11:44:24 +0800 Subject: [PATCH 037/200] Fix/compatible to old tool config (#2839) --- api/controllers/console/app/model_config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 436f8c1447..2095bb6bea 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -52,6 +52,9 @@ class ModelConfigResource(Resource): masked_parameter_map = {} tool_map = {} for tool in agent_mode.get('tools') or []: + if not isinstance(tool, dict) or len(tool.keys()) <= 3: + continue + agent_tool_entity = AgentToolEntity(**tool) # get tool try: From 156345cb4b16a7922eab03ca4a06a6cf6de99fc3 Mon Sep 17 00:00:00 2001 From: crazywoola <100913391+crazywoola@users.noreply.github.com> Date: Fri, 15 Mar 2024 12:05:35 +0800 Subject: [PATCH 038/200] fix: use supported languages only for install form (#2844) --- web/app/activate/activateForm.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/web/app/activate/activateForm.tsx b/web/app/activate/activateForm.tsx index be3706037b..eaaa86988c 100644 --- a/web/app/activate/activateForm.tsx +++ b/web/app/activate/activateForm.tsx @@ -171,8 +171,8 @@ const ActivateForm = () => {
          <SimpleSelect
            defaultValue={LanguagesSupported[0]}
-           items={languages}
+           items={languages.filter(item => item.supported)}
            onSelect={(item) => {
              setLanguage(item.value as string)
            }}

From 8a4015722dc59812344b02d6dfbfebfcfb348073 Mon Sep 17 00:00:00 2001
From: Rozstone <42225395+wststone@users.noreply.github.com>
Date: Fri, 15 Mar 2024 13:19:06 +0800
Subject: [PATCH 039/200] prevent auto scrolling down to bottom when user
 already scrolled up (#2813)

---
 web/app/components/base/chat/chat/index.tsx | 28 +++++++++++++++------
 1 file changed, 21 insertions(+), 7 deletions(-)

diff --git a/web/app/components/base/chat/chat/index.tsx b/web/app/components/base/chat/chat/index.tsx
index 2f3ae50131..2e46f1e869 100644
--- a/web/app/components/base/chat/chat/index.tsx
+++ b/web/app/components/base/chat/chat/index.tsx
@@ -4,6 +4,7 @@ import type {
 } from 'react'
 import {
   memo,
+  useCallback,
   useEffect,
   useRef,
 } from 'react'
@@ -76,19 +77,20 @@ const Chat: FC<ChatProps> = ({
   const chatContainerInnerRef = useRef<HTMLDivElement>(null)
   const chatFooterRef = useRef<HTMLDivElement>(null)
   const chatFooterInnerRef = useRef<HTMLDivElement>(null)
+  const userScrolledRef = useRef(false)
 
-  const handleScrolltoBottom = () => {
-    if (chatContainerRef.current)
+  const handleScrolltoBottom = useCallback(() => {
+    if (chatContainerRef.current && !userScrolledRef.current)
       chatContainerRef.current.scrollTop = chatContainerRef.current.scrollHeight
-  }
+  }, [])
 
-  const handleWindowResize = () => {
+  const handleWindowResize = useCallback(() => {
     if (chatContainerRef.current && chatFooterRef.current)
       chatFooterRef.current.style.width = `${chatContainerRef.current.clientWidth}px`
 
     if (chatContainerInnerRef.current && chatFooterInnerRef.current)
       chatFooterInnerRef.current.style.width = `${chatContainerInnerRef.current.clientWidth}px`
-  }
+  }, [])
 
   useThrottleEffect(() => {
     handleScrolltoBottom()
@@ -98,7 +100,7 @@
   useEffect(() => {
     window.addEventListener('resize', debounce(handleWindowResize))
     return () => window.removeEventListener('resize', handleWindowResize)
-  }, [])
+  }, [handleWindowResize])
 
   useEffect(() => {
     if (chatFooterRef.current && chatContainerRef.current) {
@@ -117,7 +119,19 @@
       resizeObserver.disconnect()
     }
   }
-  }, [chatFooterRef, chatContainerRef])
+  }, [handleScrolltoBottom])
+
+  useEffect(() => {
+    const chatContainer = chatContainerRef.current
+    if (chatContainer) {
+      const setUserScrolled = () => {
+        if (chatContainer)
+          userScrolledRef.current = chatContainer.scrollHeight - chatContainer.scrollTop >= chatContainer.clientHeight + 300
+      }
+      chatContainer.addEventListener('scroll', setUserScrolled)
+      return () => chatContainer.removeEventListener('scroll', setUserScrolled)
+    }
+  }, [])
 
   const hasTryToAsk = config?.suggested_questions_after_answer?.enabled && !!suggestedQuestions?.length && onSend

From cef16862da3b8f61a7bff235c3f86345c7811cab Mon Sep 17 00:00:00 2001
From: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
Date: Fri, 15 Mar 2024 14:02:52 +0800
Subject: [PATCH 040/200] fix: charts encoding (#2848)

---
 .../tools/provider/builtin/chart/chart.py     | 33 +++++++++++++++++++
 1 file changed, 33 insertions(+)

diff --git a/api/core/tools/provider/builtin/chart/chart.py b/api/core/tools/provider/builtin/chart/chart.py
index 813b4abcf2..f5e42e766d 100644
--- a/api/core/tools/provider/builtin/chart/chart.py
+++ b/api/core/tools/provider/builtin/chart/chart.py
@@ -1,4 +1,6 @@
 import matplotlib.pyplot as plt
+from fontTools.ttLib import TTFont
+from matplotlib.font_manager import findSystemFonts
 
 from core.tools.errors import ToolProviderCredentialValidationError
 from core.tools.provider.builtin.chart.tools.line import LinearChartTool
@@
-6,6 +8,37 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl # use a business theme plt.style.use('seaborn-v0_8-darkgrid') +plt.rcParams['axes.unicode_minus'] = False + +def init_fonts(): + fonts = findSystemFonts() + + popular_unicode_fonts = [ + 'Arial Unicode MS', 'DejaVu Sans', 'DejaVu Sans Mono', 'DejaVu Serif', 'FreeMono', 'FreeSans', 'FreeSerif', + 'Liberation Mono', 'Liberation Sans', 'Liberation Serif', 'Noto Mono', 'Noto Sans', 'Noto Serif', 'Open Sans', + 'Roboto', 'Source Code Pro', 'Source Sans Pro', 'Source Serif Pro', 'Ubuntu', 'Ubuntu Mono' + ] + + supported_fonts = [] + + for font_path in fonts: + try: + font = TTFont(font_path) + # get family name + family_name = font['name'].getName(1, 3, 1).toUnicode() + if family_name in popular_unicode_fonts: + supported_fonts.append(family_name) + except: + pass + + plt.rcParams['font.family'] = 'sans-serif' + # sort by order of popular_unicode_fonts + for font in popular_unicode_fonts: + if font in supported_fonts: + plt.rcParams['font.sans-serif'] = font + break + +init_fonts() class ChartProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: From 381b3d5016d0f8672b81a0d91fbc5f544d2690d0 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 19 Feb 2024 16:55:59 +0800 Subject: [PATCH 041/200] optimize get app model to wraps --- api/controllers/console/__init__.py | 2 +- api/controllers/console/app/__init__.py | 21 ---- api/controllers/console/app/app.py | 100 +++++++----------- api/controllers/console/app/audio.py | 23 ++-- api/controllers/console/app/completion.py | 36 ++----- api/controllers/console/app/conversation.py | 59 ++++------- api/controllers/console/app/message.py | 64 ++++------- api/controllers/console/app/model_config.py | 17 ++- api/controllers/console/app/site.py | 14 +-- api/controllers/console/app/statistic.py | 38 +++---- api/controllers/console/app/workflow.py | 20 ++++ api/controllers/console/app/wraps.py | 55 ++++++++++ api/core/app_runner/basic_app_runner.py | 4 +- api/core/entities/application_entities.py | 20 ++++ api/core/prompt/prompt_transform.py | 20 +--- .../advanced_prompt_template_service.py | 2 +- api/services/app_model_config_service.py | 2 +- 17 files changed, 232 insertions(+), 265 deletions(-) create mode 100644 api/controllers/console/app/workflow.py create mode 100644 api/controllers/console/app/wraps.py diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index ecfdc38612..934b19116b 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -8,7 +8,7 @@ api = ExternalApi(bp) from . 
import admin, apikey, extension, feature, setup, version # Import app controllers from .app import (advanced_prompt_template, annotation, app, audio, completion, conversation, generator, message, - model_config, site, statistic) + model_config, site, statistic, workflow) # Import auth controllers from .auth import activate, data_source_oauth, login, oauth # Import billing controllers diff --git a/api/controllers/console/app/__init__.py b/api/controllers/console/app/__init__.py index b0b07517f1..e69de29bb2 100644 --- a/api/controllers/console/app/__init__.py +++ b/api/controllers/console/app/__init__.py @@ -1,21 +0,0 @@ -from controllers.console.app.error import AppUnavailableError -from extensions.ext_database import db -from flask_login import current_user -from models.model import App -from werkzeug.exceptions import NotFound - - -def _get_app(app_id, mode=None): - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() - - if not app: - raise NotFound("App not found") - - if mode and app.mode != mode: - raise NotFound("The {} app not found".format(mode)) - - return app diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index ff97405415..c366ace93a 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -9,7 +9,8 @@ from werkzeug.exceptions import Forbidden from constants.languages import demo_model_templates, languages from constants.model_template import model_templates from controllers.console import api -from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError +from controllers.console.app.error import ProviderNotInitializeError +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError @@ -31,13 +32,6 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.tool_manager import ToolManager from core.entities.application_entities import AgentToolEntity -def _get_app(app_id, tenant_id): - app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first() - if not app: - raise AppNotFoundError - return app - - class AppListApi(Resource): @setup_required @@ -234,14 +228,12 @@ class AppApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_detail_fields_with_site) - def get(self, app_id): + def get(self, app_model): """Get app detail""" - app_id = str(app_id) - app: App = _get_app(app_id, current_user.current_tenant_id) - # get original app model config - model_config: AppModelConfig = app.app_model_config + model_config: AppModelConfig = app_model.app_model_config agent_mode = model_config.agent_mode_dict # decrypt agent tool parameters if it's secret-input for tool in agent_mode.get('tools') or []: @@ -277,27 +269,24 @@ class AppApi(Resource): # override agent mode model_config.agent_mode = json.dumps(agent_mode) - return app + return app_model @setup_required @login_required @account_initialization_required - def delete(self, app_id): + @get_app_model + def delete(self, app_model): """Delete app""" - app_id = str(app_id) - if not current_user.is_admin_or_owner: raise Forbidden() - app = _get_app(app_id, current_user.current_tenant_id) - - 
db.session.delete(app) + db.session.delete(app_model) db.session.commit() # todo delete related data?? # model_config, site, api_token, conversation, message, message_feedback, message_annotation - app_was_deleted.send(app) + app_was_deleted.send(app_model) return {'result': 'success'}, 204 @@ -306,86 +295,77 @@ class AppNameApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_detail_fields) - def post(self, app_id): - app_id = str(app_id) - app = _get_app(app_id, current_user.current_tenant_id) - + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('name', type=str, required=True, location='json') args = parser.parse_args() - app.name = args.get('name') - app.updated_at = datetime.utcnow() + app_model.name = args.get('name') + app_model.updated_at = datetime.utcnow() db.session.commit() - return app + return app_model class AppIconApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_detail_fields) - def post(self, app_id): - app_id = str(app_id) - app = _get_app(app_id, current_user.current_tenant_id) - + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('icon', type=str, location='json') parser.add_argument('icon_background', type=str, location='json') args = parser.parse_args() - app.icon = args.get('icon') - app.icon_background = args.get('icon_background') - app.updated_at = datetime.utcnow() + app_model.icon = args.get('icon') + app_model.icon_background = args.get('icon_background') + app_model.updated_at = datetime.utcnow() db.session.commit() - return app + return app_model class AppSiteStatus(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_detail_fields) - def post(self, app_id): + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('enable_site', type=bool, required=True, location='json') args = parser.parse_args() - app_id = str(app_id) - app = db.session.query(App).filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id).first() - if not app: - raise AppNotFoundError - if args.get('enable_site') == app.enable_site: - return app + if args.get('enable_site') == app_model.enable_site: + return app_model - app.enable_site = args.get('enable_site') - app.updated_at = datetime.utcnow() + app_model.enable_site = args.get('enable_site') + app_model.updated_at = datetime.utcnow() db.session.commit() - return app + return app_model class AppApiStatus(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_detail_fields) - def post(self, app_id): + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('enable_api', type=bool, required=True, location='json') args = parser.parse_args() - app_id = str(app_id) - app = _get_app(app_id, current_user.current_tenant_id) + if args.get('enable_api') == app_model.enable_api: + return app_model - if args.get('enable_api') == app.enable_api: - return app - - app.enable_api = args.get('enable_api') - app.updated_at = datetime.utcnow() + app_model.enable_api = args.get('enable_api') + app_model.updated_at = datetime.utcnow() db.session.commit() - return app + return app_model class AppCopy(Resource): @@ -415,16 +395,14 @@ class AppCopy(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_detail_fields) - def post(self, 
app_id): - app_id = str(app_id) - app = _get_app(app_id, current_user.current_tenant_id) - - copy_app = self.create_app_copy(app) + def post(self, app_model): + copy_app = self.create_app_copy(app_model) db.session.add(copy_app) app_config = db.session.query(AppModelConfig). \ - filter(AppModelConfig.app_id == app_id). \ + filter(AppModelConfig.app_id == app_model.id). \ one_or_none() if app_config: diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 77eaf136fc..daa5570f9a 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -6,7 +6,6 @@ from werkzeug.exceptions import InternalServerError import services from controllers.console import api -from controllers.console.app import _get_app from controllers.console.app.error import ( AppUnavailableError, AudioTooLargeError, @@ -18,8 +17,10 @@ from controllers.console.app.error import ( ProviderQuotaExceededError, UnsupportedAudioTypeError, ) +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from core.entities.application_entities import AppMode from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.login import login_required @@ -36,10 +37,8 @@ class ChatMessageAudioApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): - app_id = str(app_id) - app_model = _get_app(app_id, 'chat') - + @get_app_model(mode=AppMode.CHAT) + def post(self, app_model): file = request.files['file'] try: @@ -80,10 +79,8 @@ class ChatMessageTextApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): - app_id = str(app_id) - app_model = _get_app(app_id, None) - + @get_app_model + def post(self, app_model): try: response = AudioService.transcript_tts( tenant_id=app_model.tenant_id, @@ -120,9 +117,11 @@ class ChatMessageTextApi(Resource): class TextModesApi(Resource): - def get(self, app_id: str): - app_model = _get_app(str(app_id)) - + @setup_required + @login_required + @account_initialization_required + @get_app_model + def get(self, app_model): try: parser = reqparse.RequestParser() parser.add_argument('language', type=str, required=True, location='args') diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index f01d2afa03..f378f7b218 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -10,7 +10,6 @@ from werkzeug.exceptions import InternalServerError, NotFound import services from controllers.console import api -from controllers.console.app import _get_app from controllers.console.app.error import ( AppUnavailableError, CompletionRequestError, @@ -19,10 +18,11 @@ from controllers.console.app.error import ( ProviderNotInitializeError, ProviderQuotaExceededError, ) +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.application_queue_manager import ApplicationQueueManager -from core.entities.application_entities import InvokeFrom +from core.entities.application_entities import InvokeFrom, AppMode from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from 
core.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value @@ -36,12 +36,8 @@ class CompletionMessageApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): - app_id = str(app_id) - - # get app info - app_model = _get_app(app_id, 'completion') - + @get_app_model(mode=AppMode.WORKFLOW) + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('query', type=str, location='json', default='') @@ -93,12 +89,8 @@ class CompletionMessageStopApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id, task_id): - app_id = str(app_id) - - # get app info - _get_app(app_id, 'completion') - + @get_app_model(mode=AppMode.WORKFLOW) + def post(self, app_model, task_id): account = flask_login.current_user ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) @@ -110,12 +102,8 @@ class ChatMessageApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): - app_id = str(app_id) - - # get app info - app_model = _get_app(app_id, 'chat') - + @get_app_model(mode=AppMode.CHAT) + def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('query', type=str, required=True, location='json') @@ -179,12 +167,8 @@ class ChatMessageStopApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id, task_id): - app_id = str(app_id) - - # get app info - _get_app(app_id, 'chat') - + @get_app_model(mode=AppMode.CHAT) + def post(self, app_model, task_id): account = flask_login.current_user ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 452b0fddf6..4ee1ee4035 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -9,9 +9,10 @@ from sqlalchemy.orm import joinedload from werkzeug.exceptions import NotFound from controllers.console import api -from controllers.console.app import _get_app +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from core.entities.application_entities import AppMode from extensions.ext_database import db from fields.conversation_fields import ( conversation_detail_fields, @@ -29,10 +30,9 @@ class CompletionConversationApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model(mode=AppMode.WORKFLOW) @marshal_with(conversation_pagination_fields) - def get(self, app_id): - app_id = str(app_id) - + def get(self, app_model): parser = reqparse.RequestParser() parser.add_argument('keyword', type=str, location='args') parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -43,10 +43,7 @@ class CompletionConversationApi(Resource): parser.add_argument('limit', type=int_range(1, 100), default=20, location='args') args = parser.parse_args() - # get app info - app = _get_app(app_id, 'completion') - - query = db.select(Conversation).where(Conversation.app_id == app.id, Conversation.mode == 'completion') + query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == 
'completion')
 
         if args['keyword']:
             query = query.join(
@@ -106,24 +103,22 @@ class CompletionConversationDetailApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @get_app_model(mode=AppMode.WORKFLOW)
     @marshal_with(conversation_message_detail_fields)
-    def get(self, app_id, conversation_id):
-        app_id = str(app_id)
+    def get(self, app_model, conversation_id):
         conversation_id = str(conversation_id)
 
-        return _get_conversation(app_id, conversation_id, 'completion')
+        return _get_conversation(app_model, conversation_id)
 
     @setup_required
     @login_required
     @account_initialization_required
-    def delete(self, app_id, conversation_id):
-        app_id = str(app_id)
+    @get_app_model(mode=AppMode.CHAT)
+    def delete(self, app_model, conversation_id):
         conversation_id = str(conversation_id)
 
-        app = _get_app(app_id, 'chat')
-
         conversation = db.session.query(Conversation) \
-            .filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first()
+            .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
 
         if not conversation:
             raise NotFound("Conversation Not Exists.")
@@ -139,10 +134,9 @@ class ChatConversationApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @get_app_model(mode=AppMode.CHAT)
     @marshal_with(conversation_with_summary_pagination_fields)
-    def get(self, app_id):
-        app_id = str(app_id)
-
+    def get(self, app_model):
         parser = reqparse.RequestParser()
         parser.add_argument('keyword', type=str, location='args')
         parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
@@ -154,10 +148,7 @@ class ChatConversationApi(Resource):
         parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
         args = parser.parse_args()
 
-        # get app info
-        app = _get_app(app_id, 'chat')
-
-        query = db.select(Conversation).where(Conversation.app_id == app.id, Conversation.mode == 'chat')
+        query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == 'chat')
 
         if args['keyword']:
             query = query.join(
@@ -228,25 +219,22 @@ class ChatConversationDetailApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @get_app_model(mode=AppMode.CHAT)
     @marshal_with(conversation_detail_fields)
-    def get(self, app_id, conversation_id):
-        app_id = str(app_id)
+    def get(self, app_model, conversation_id):
         conversation_id = str(conversation_id)
 
-        return _get_conversation(app_id, conversation_id, 'chat')
+        return _get_conversation(app_model, conversation_id)
 
     @setup_required
     @login_required
+    @get_app_model(mode=AppMode.CHAT)
     @account_initialization_required
-    def delete(self, app_id, conversation_id):
-        app_id = str(app_id)
+    def delete(self, app_model, conversation_id):
         conversation_id = str(conversation_id)
 
-        # get app info
-        app = _get_app(app_id, 'chat')
-
         conversation = db.session.query(Conversation) \
-            .filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first()
+            .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
 
         if not conversation:
             raise NotFound("Conversation Not Exists.")
@@ -263,12 +251,9 @@ api.add_resource(ChatConversationApi, '/apps/<uuid:app_id>/chat-conversations')
 api.add_resource(ChatConversationDetailApi, '/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>')
 
 
-def _get_conversation(app_id, conversation_id, mode):
-    # get app info
-    app = _get_app(app_id, mode)
-
+def _get_conversation(app_model, conversation_id):
     conversation = db.session.query(Conversation) \
-        .filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first()
+        .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
 
     if not conversation:
         raise NotFound("Conversation Not Exists.")
diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py
index 0064dbe663..360602b9c2 100644
--- a/api/controllers/console/app/message.py
+++ b/api/controllers/console/app/message.py
@@ -10,7 +10,6 @@ from flask_restful.inputs import int_range
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 
 from controllers.console import api
-from controllers.console.app import _get_app
 from controllers.console.app.error import (
     AppMoreLikeThisDisabledError,
     CompletionRequestError,
@@ -18,9 +17,10 @@ from controllers.console.app.error import (
     ProviderNotInitializeError,
     ProviderQuotaExceededError,
 )
+from controllers.console.app.wraps import get_app_model
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
-from core.entities.application_entities import InvokeFrom
+from core.entities.application_entities import InvokeFrom, AppMode
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
 from extensions.ext_database import db
@@ -46,14 +46,10 @@ class ChatMessageListApi(Resource):
 
     @setup_required
     @login_required
+    @get_app_model(mode=AppMode.CHAT)
     @account_initialization_required
     @marshal_with(message_infinite_scroll_pagination_fields)
-    def get(self, app_id):
-        app_id = str(app_id)
-
-        # get app info
-        app = _get_app(app_id, 'chat')
-
+    def get(self, app_model):
         parser = reqparse.RequestParser()
         parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
         parser.add_argument('first_id', type=uuid_value, location='args')
@@ -62,7 +58,7 @@ class ChatMessageListApi(Resource):
 
         conversation = db.session.query(Conversation).filter(
             Conversation.id == args['conversation_id'],
-            Conversation.app_id == app.id
+            Conversation.app_id == app_model.id
         ).first()
 
         if not conversation:
@@ -110,12 +106,8 @@ class MessageFeedbackApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, app_id):
-        app_id = str(app_id)
-
-        # get app info
-        app = _get_app(app_id)
-
+    @get_app_model
+    def post(self, app_model):
         parser = reqparse.RequestParser()
         parser.add_argument('message_id', required=True, type=uuid_value, location='json')
         parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
@@ -125,7 +117,7 @@ class MessageFeedbackApi(Resource):
 
         message = db.session.query(Message).filter(
             Message.id == message_id,
-            Message.app_id == app.id
+            Message.app_id == app_model.id
         ).first()
 
         if not message:
@@ -141,7 +133,7 @@ class MessageFeedbackApi(Resource):
                 raise ValueError('rating cannot be None when feedback not exists')
         else:
             feedback = MessageFeedback(
-                app_id=app.id,
+                app_id=app_model.id,
                 conversation_id=message.conversation_id,
                 message_id=message.id,
                 rating=args['rating'],
@@ -160,21 +152,20 @@ class MessageAnnotationApi(Resource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check('annotation')
+    @get_app_model
     @marshal_with(annotation_fields)
-    def post(self, app_id):
+    def post(self, app_model):
         # The role of the current user in the ta table must be admin or owner
         if not current_user.is_admin_or_owner:
             raise 
Forbidden() - app_id = str(app_id) - parser = reqparse.RequestParser() parser.add_argument('message_id', required=False, type=uuid_value, location='json') parser.add_argument('question', required=True, type=str, location='json') parser.add_argument('answer', required=True, type=str, location='json') parser.add_argument('annotation_reply', required=False, type=dict, location='json') args = parser.parse_args() - annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id) + annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id) return annotation @@ -183,14 +174,10 @@ class MessageAnnotationCountApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): - app_id = str(app_id) - - # get app info - app = _get_app(app_id) - + @get_app_model + def get(self, app_model): count = db.session.query(MessageAnnotation).filter( - MessageAnnotation.app_id == app.id + MessageAnnotation.app_id == app_model.id ).count() return {'count': count} @@ -200,8 +187,8 @@ class MessageMoreLikeThisApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id, message_id): - app_id = str(app_id) + @get_app_model(mode=AppMode.COMPLETION) + def get(self, app_model, message_id): message_id = str(message_id) parser = reqparse.RequestParser() @@ -211,9 +198,6 @@ class MessageMoreLikeThisApi(Resource): streaming = args['response_mode'] == 'streaming' - # get app info - app_model = _get_app(app_id, 'completion') - try: response = CompletionService.generate_more_like_this( app_model=app_model, @@ -257,13 +241,10 @@ class MessageSuggestedQuestionApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id, message_id): - app_id = str(app_id) + @get_app_model(mode=AppMode.CHAT) + def get(self, app_model, message_id): message_id = str(message_id) - # get app info - app_model = _get_app(app_id, 'chat') - try: questions = MessageService.get_suggested_questions_after_answer( app_model=app_model, @@ -294,14 +275,11 @@ class MessageApi(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(message_detail_fields) - def get(self, app_id, message_id): - app_id = str(app_id) + def get(self, app_model, message_id): message_id = str(message_id) - # get app info - app_model = _get_app(app_id) - message = db.session.query(Message).filter( Message.id == message_id, Message.app_id == app_model.id diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 2095bb6bea..0f8bc28f6f 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -5,7 +5,7 @@ from flask_login import current_user from flask_restful import Resource from controllers.console import api -from controllers.console.app import _get_app +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.entities.application_entities import AgentToolEntity @@ -23,22 +23,19 @@ class ModelConfigResource(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): + @get_app_model + def post(self, app_model): """Modify app model config""" - app_id = str(app_id) - - app = _get_app(app_id) - # validate config model_configuration = AppModelConfigService.validate_configuration( 
tenant_id=current_user.current_tenant_id, account=current_user, config=request.json, - app_mode=app.mode + app_mode=app_model.mode ) new_app_model_config = AppModelConfig( - app_id=app.id, + app_id=app_model.id, ) new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) @@ -130,11 +127,11 @@ class ModelConfigResource(Resource): db.session.add(new_app_model_config) db.session.flush() - app.app_model_config_id = new_app_model_config.id + app_model.app_model_config_id = new_app_model_config.id db.session.commit() app_model_config_was_updated.send( - app, + app_model, app_model_config=new_app_model_config ) diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 4e9d9ed9b4..256824981e 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -4,7 +4,7 @@ from werkzeug.exceptions import Forbidden, NotFound from constants.languages import supported_language from controllers.console import api -from controllers.console.app import _get_app +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from extensions.ext_database import db @@ -34,13 +34,11 @@ class AppSite(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_site_fields) - def post(self, app_id): + def post(self, app_model): args = parse_app_site_args() - app_id = str(app_id) - app_model = _get_app(app_id) - # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() @@ -82,11 +80,9 @@ class AppSiteAccessTokenReset(Resource): @setup_required @login_required @account_initialization_required + @get_app_model @marshal_with(app_site_fields) - def post(self, app_id): - app_id = str(app_id) - app_model = _get_app(app_id) - + def post(self, app_model): # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 7aed7da404..e3bc44d6e9 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -7,9 +7,10 @@ from flask_login import current_user from flask_restful import Resource, reqparse from controllers.console import api -from controllers.console.app import _get_app +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from core.entities.application_entities import AppMode from extensions.ext_database import db from libs.helper import datetime_string from libs.login import login_required @@ -20,10 +21,9 @@ class DailyConversationStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + @get_app_model + def get(self, app_model): account = current_user - app_id = str(app_id) - app_model = _get_app(app_id) parser = reqparse.RequestParser() parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') @@ -81,10 +81,9 @@ class DailyTerminalsStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + @get_app_model + def get(self, app_model): account = current_user - app_id = str(app_id) - app_model = _get_app(app_id) parser = 
reqparse.RequestParser()
         parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
@@ -141,10 +140,9 @@ class DailyTokenCostStatistic(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def get(self, app_id):
+    @get_app_model
+    def get(self, app_model):
         account = current_user
-        app_id = str(app_id)
-        app_model = _get_app(app_id)
 
         parser = reqparse.RequestParser()
         parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
@@ -205,10 +203,9 @@ class AverageSessionInteractionStatistic(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def get(self, app_id):
+    @get_app_model(mode=AppMode.CHAT)
+    def get(self, app_model):
         account = current_user
-        app_id = str(app_id)
-        app_model = _get_app(app_id, 'chat')
 
         parser = reqparse.RequestParser()
         parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
@@ -271,10 +268,9 @@ class UserSatisfactionRateStatistic(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def get(self, app_id):
+    @get_app_model
+    def get(self, app_model):
         account = current_user
-        app_id = str(app_id)
-        app_model = _get_app(app_id)
 
         parser = reqparse.RequestParser()
         parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
@@ -334,10 +330,9 @@ class AverageResponseTimeStatistic(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def get(self, app_id):
+    @get_app_model(mode=AppMode.WORKFLOW)
+    def get(self, app_model):
         account = current_user
-        app_id = str(app_id)
-        app_model = _get_app(app_id, 'completion')
 
         parser = reqparse.RequestParser()
         parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
@@ -396,10 +391,9 @@ class TokensPerSecondStatistic(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def get(self, app_id):
+    @get_app_model
+    def get(self, app_model):
         account = current_user
-        app_id = str(app_id)
-        app_model = _get_app(app_id)
 
         parser = reqparse.RequestParser()
         parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
new file mode 100644
index 0000000000..5a08e31c16
--- /dev/null
+++ b/api/controllers/console/app/workflow.py
@@ -0,0 +1,20 @@
+from flask_restful import Resource
+
+from controllers.console import api
+from controllers.console.app.wraps import get_app_model
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from core.entities.application_entities import AppMode
+from libs.login import login_required
+
+
+class DefaultBlockConfigApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model(mode=[AppMode.CHAT, AppMode.WORKFLOW])
+    def post(self, app_model):
+        return 'success', 200
+
+
+api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/default-workflow-block-configs')
diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py
new file mode 100644
index 0000000000..b3aca51871
--- /dev/null
+++ b/api/controllers/console/app/wraps.py
@@ -0,0 +1,55 @@
+from functools import wraps
+from typing import Union, Optional, Callable
+
+from controllers.console.app.error import AppNotFoundError
+from core.entities.application_entities import AppMode
+from extensions.ext_database import db
+from libs.login import current_user
+from models.model 
import App + + +def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None): + def decorator(view_func): + @wraps(view_func) + def decorated_view(*args, **kwargs): + if not kwargs.get('app_id'): + raise ValueError('missing app_id in path parameters') + + app_id = kwargs.get('app_id') + app_id = str(app_id) + + del kwargs['app_id'] + + app_model = db.session.query(App).filter( + App.id == app_id, + App.tenant_id == current_user.current_tenant_id, + App.status == 'normal' + ).first() + + if not app_model: + raise AppNotFoundError() + + app_mode = AppMode.value_of(app_model.mode) + if mode is not None: + if isinstance(mode, list): + modes = mode + else: + modes = [mode] + + # [temp] if workflow is in the mode list, then completion should be in the mode list + if AppMode.WORKFLOW in modes: + modes.append(AppMode.COMPLETION) + + if app_mode not in modes: + mode_values = {m.value for m in modes} + raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}") + + kwargs['app_model'] = app_model + + return view_func(*args, **kwargs) + return decorated_view + + if view is None: + return decorator + else: + return decorator(view) diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app_runner/basic_app_runner.py index d3c91337c8..d1e16f860c 100644 --- a/api/core/app_runner/basic_app_runner.py +++ b/api/core/app_runner/basic_app_runner.py @@ -4,12 +4,12 @@ from typing import Optional from core.app_runner.app_runner import AppRunner from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity +from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity, \ + AppMode from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationException -from core.prompt.prompt_transform import AppMode from extensions.ext_database import db from models.model import App, Conversation, Message diff --git a/api/core/entities/application_entities.py b/api/core/entities/application_entities.py index abcf605c92..d3231affb2 100644 --- a/api/core/entities/application_entities.py +++ b/api/core/entities/application_entities.py @@ -9,6 +9,26 @@ from core.model_runtime.entities.message_entities import PromptMessageRole from core.model_runtime.entities.model_entities import AIModelEntity +class AppMode(Enum): + COMPLETION = 'completion' # will be deprecated in the future + WORKFLOW = 'workflow' # instead of 'completion' + CHAT = 'chat' + AGENT = 'agent' + + @classmethod + def value_of(cls, value: str) -> 'AppMode': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid mode value {value}') + + class ModelConfigEntity(BaseModel): """ Model Config Entity. 
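The `get_app_model` decorator introduced above, together with `AppMode.value_of`, is the pattern the console controllers in this series migrate to: the view receives a validated `app_model` instead of re-resolving `app_id` by hand. A minimal usage sketch, assuming a flask-restful resource registered on a route that carries an `app_id` path parameter (the class name and payload below are illustrative, not part of the diffs):

    from flask_restful import Resource

    from controllers.console.app.wraps import get_app_model
    from core.entities.application_entities import AppMode


    class ExampleStatisticApi(Resource):
        @get_app_model(mode=AppMode.CHAT)
        def get(self, app_model):
            # the decorator consumed the app_id kwarg, loaded the App row
            # scoped to the current tenant, and injected it as app_model
            return {'id': app_model.id, 'mode': app_model.mode}

Because `view` is an optional positional argument, both the bare `@get_app_model` form and the parameterized `@get_app_model(mode=...)` form work, which is how the statistics endpoints above mix the two.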
diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 0a373b7c42..08d94661b7 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -7,7 +7,7 @@ from typing import Optional, cast from core.entities.application_entities import ( AdvancedCompletionPromptTemplateEntity, ModelConfigEntity, - PromptTemplateEntity, + PromptTemplateEntity, AppMode, ) from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory @@ -25,24 +25,6 @@ from core.prompt.prompt_builder import PromptBuilder from core.prompt.prompt_template import PromptTemplateParser -class AppMode(enum.Enum): - COMPLETION = 'completion' - CHAT = 'chat' - - @classmethod - def value_of(cls, value: str) -> 'AppMode': - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f'invalid mode value {value}') - - class ModelMode(enum.Enum): COMPLETION = 'completion' CHAT = 'chat' diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index d52f6e20c2..3cf58d8e09 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -1,6 +1,7 @@ import copy +from core.entities.application_entities import AppMode from core.prompt.advanced_prompt_templates import ( BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, @@ -13,7 +14,6 @@ from core.prompt.advanced_prompt_templates import ( COMPLETION_APP_COMPLETION_PROMPT_CONFIG, CONTEXT, ) -from core.prompt.prompt_transform import AppMode class AdvancedPromptTemplateService: diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 2e21e56266..ccfb101405 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -2,11 +2,11 @@ import re import uuid from core.entities.agent_entities import PlanningStrategy +from core.entities.application_entities import AppMode from core.external_data_tool.factory import ExternalDataToolFactory from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.model_providers import model_provider_factory from core.moderation.factory import ModerationFactory -from core.prompt.prompt_transform import AppMode from core.provider_manager import ProviderManager from models.account import Account from services.dataset_service import DatasetService From d430136f656606bf8c7bb1c3bed7492d4b901dfb Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 19 Feb 2024 16:56:29 +0800 Subject: [PATCH 042/200] lint --- api/controllers/console/app/completion.py | 2 +- api/controllers/console/app/message.py | 2 +- api/controllers/console/app/wraps.py | 3 ++- api/core/app_runner/basic_app_runner.py | 9 +++++++-- api/core/prompt/prompt_transform.py | 3 ++- 5 files changed, 13 insertions(+), 6 deletions(-) diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index f378f7b218..381d0bbb6b 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -22,7 +22,7 @@ from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.application_queue_manager import ApplicationQueueManager -from 
core.entities.application_entities import InvokeFrom, AppMode +from core.entities.application_entities import AppMode, InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 360602b9c2..5d4f6b7e26 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -20,7 +20,7 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check -from core.entities.application_entities import InvokeFrom, AppMode +from core.entities.application_entities import AppMode, InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index b3aca51871..fe2b408702 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -1,5 +1,6 @@ +from collections.abc import Callable from functools import wraps -from typing import Union, Optional, Callable +from typing import Optional, Union from controllers.console.app.error import AppNotFoundError from core.entities.application_entities import AppMode diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app_runner/basic_app_runner.py index d1e16f860c..d87302c717 100644 --- a/api/core/app_runner/basic_app_runner.py +++ b/api/core/app_runner/basic_app_runner.py @@ -4,8 +4,13 @@ from typing import Optional from core.app_runner.app_runner import AppRunner from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity, \ - AppMode +from core.entities.application_entities import ( + ApplicationGenerateEntity, + AppMode, + DatasetEntity, + InvokeFrom, + ModelConfigEntity, +) from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 08d94661b7..4bf96ce265 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -6,8 +6,9 @@ from typing import Optional, cast from core.entities.application_entities import ( AdvancedCompletionPromptTemplateEntity, + AppMode, ModelConfigEntity, - PromptTemplateEntity, AppMode, + PromptTemplateEntity, ) from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory From b7c6cba23f24625f41a5446abcec6e210354f04d Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 19 Feb 2024 20:48:54 +0800 Subject: [PATCH 043/200] add workflow models --- api/controllers/console/app/workflow.py | 21 +- .../versions/b289e2408ee2_add_workflow.py | 143 +++++++++++ api/models/model.py | 20 +- api/models/workflow.py | 237 ++++++++++++++++++ 4 files changed, 415 insertions(+), 6 
deletions(-) create mode 100644 api/migrations/versions/b289e2408ee2_add_workflow.py create mode 100644 api/models/workflow.py diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 5a08e31c16..4acdb4943d 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restful import Resource, reqparse from controllers.console import api from controllers.console.app.wraps import get_app_model @@ -12,9 +12,20 @@ class DefaultBlockConfigApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=[AppMode.CHAT, AppMode.WORKFLOW]) - def post(self, app_model): - return 'success', 200 + def get(self): + parser = reqparse.RequestParser() + parser.add_argument('app_mode', type=str, required=True, nullable=False, + choices=[AppMode.CHAT.value, AppMode.WORKFLOW.value], location='args') + args = parser.parse_args() + + app_mode = args.get('app_mode') + app_mode = AppMode.value_of(app_mode) + + # TODO: implement this + + return { + "blocks": [] + } -api.add_resource(DefaultBlockConfigApi, '/apps//default-workflow-block-configs') +api.add_resource(DefaultBlockConfigApi, '/default-workflow-block-configs') diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py new file mode 100644 index 0000000000..52168a04e7 --- /dev/null +++ b/api/migrations/versions/b289e2408ee2_add_workflow.py @@ -0,0 +1,143 @@ +"""add workflow + +Revision ID: b289e2408ee2 +Revises: 16830a790f0f +Create Date: 2024-02-19 12:47:24.646954 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'b289e2408ee2' +down_revision = '16830a790f0f' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('workflow_app_logs', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('workflow_id', postgresql.UUID(), nullable=False), + sa.Column('workflow_run_id', postgresql.UUID(), nullable=False), + sa.Column('created_from', sa.String(length=255), nullable=False), + sa.Column('created_by_role', sa.String(length=255), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='workflow_app_log_pkey') + ) + with op.batch_alter_table('workflow_app_logs', schema=None) as batch_op: + batch_op.create_index('workflow_app_log_app_idx', ['tenant_id', 'app_id'], unique=False) + + op.create_table('workflow_node_executions', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('workflow_id', postgresql.UUID(), nullable=False), + sa.Column('triggered_from', sa.String(length=255), nullable=False), + sa.Column('workflow_run_id', postgresql.UUID(), nullable=True), + sa.Column('index', sa.Integer(), nullable=False), + sa.Column('predecessor_node_id', sa.String(length=255), nullable=True), + sa.Column('node_id', sa.String(length=255), nullable=False), + sa.Column('node_type', sa.String(length=255), nullable=False), + sa.Column('title', sa.String(length=255), nullable=False), + sa.Column('inputs', sa.Text(), nullable=False), + sa.Column('process_data', sa.Text(), nullable=False), + sa.Column('outputs', sa.Text(), nullable=True), + sa.Column('status', sa.String(length=255), nullable=False), + sa.Column('error', sa.Text(), nullable=True), + sa.Column('elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False), + sa.Column('execution_metadata', sa.Text(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('finished_at', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id', name='workflow_node_execution_pkey') + ) + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.create_index('workflow_node_execution_node_run_idx', ['tenant_id', 'app_id', 'workflow_id', 'triggered_from', 'node_id'], unique=False) + batch_op.create_index('workflow_node_execution_workflow_run_idx', ['tenant_id', 'app_id', 'workflow_id', 'triggered_from', 'workflow_run_id'], unique=False) + + op.create_table('workflow_runs', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('sequence_number', sa.Integer(), nullable=False), + sa.Column('workflow_id', postgresql.UUID(), nullable=False), + sa.Column('type', sa.String(length=255), nullable=False), + sa.Column('triggered_from', sa.String(length=255), nullable=False), + sa.Column('version', sa.String(length=255), nullable=False), + sa.Column('graph', sa.Text(), nullable=True), + sa.Column('inputs', sa.Text(), nullable=True), + sa.Column('status', sa.String(length=255), nullable=False), + 
sa.Column('outputs', sa.Text(), nullable=True), + sa.Column('error', sa.Text(), nullable=True), + sa.Column('elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False), + sa.Column('total_tokens', sa.Integer(), server_default=sa.text('0'), nullable=False), + sa.Column('total_price', sa.Numeric(precision=10, scale=7), nullable=True), + sa.Column('currency', sa.String(length=255), nullable=True), + sa.Column('total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('finished_at', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id', name='workflow_run_pkey') + ) + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.create_index('workflow_run_triggerd_from_idx', ['tenant_id', 'app_id', 'workflow_id', 'triggered_from'], unique=False) + + op.create_table('workflows', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('type', sa.String(length=255), nullable=False), + sa.Column('version', sa.String(length=255), nullable=False), + sa.Column('graph', sa.Text(), nullable=True), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_by', postgresql.UUID(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id', name='workflow_pkey') + ) + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.create_index('workflow_version_idx', ['tenant_id', 'app_id', 'type', 'version'], unique=False) + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('chatbot_app_engine', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False)) + batch_op.add_column(sa.Column('workflow_id', postgresql.UUID(), nullable=True)) + + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.add_column(sa.Column('workflow_run_id', postgresql.UUID(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.drop_column('workflow_run_id') + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.drop_column('workflow_id') + batch_op.drop_column('chatbot_app_engine') + + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.drop_index('workflow_version_idx') + + op.drop_table('workflows') + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.drop_index('workflow_run_triggerd_from_idx') + + op.drop_table('workflow_runs') + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.drop_index('workflow_node_execution_workflow_run_idx') + batch_op.drop_index('workflow_node_execution_node_run_idx') + + op.drop_table('workflow_node_executions') + with op.batch_alter_table('workflow_app_logs', schema=None) as batch_op: + batch_op.drop_index('workflow_app_log_app_idx') + + op.drop_table('workflow_app_logs') + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index 8776f89673..6e7a58ed45 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -12,6 +12,7 @@ from extensions.ext_database import db from libs.helper import generate_string from .account import Account, Tenant +from .workflow import WorkflowRun, Workflow class DifySetup(db.Model): @@ -156,12 +157,14 @@ class AppModelConfig(db.Model): agent_mode = db.Column(db.Text) sensitive_word_avoidance = db.Column(db.Text) retriever_resource = db.Column(db.Text) - prompt_type = db.Column(db.String(255), nullable=False, default='simple') + prompt_type = db.Column(db.String(255), nullable=False, server_default=db.text("'simple'::character varying")) chat_prompt_config = db.Column(db.Text) completion_prompt_config = db.Column(db.Text) dataset_configs = db.Column(db.Text) external_data_tools = db.Column(db.Text) file_upload = db.Column(db.Text) + chatbot_app_engine = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + workflow_id = db.Column(UUID) @property def app(self): @@ -261,6 +264,13 @@ class AppModelConfig(db.Model): "image": {"enabled": False, "number_limits": 3, "detail": "high", "transfer_methods": ["remote_url", "local_file"]}} + @property + def workflow(self): + if self.workflow_id: + return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() + + return None + def to_dict(self) -> dict: return { "provider": "", @@ -581,6 +591,7 @@ class Message(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + workflow_run_id = db.Column(UUID) @property def user_feedback(self): @@ -679,6 +690,13 @@ class Message(db.Model): return files + @property + def workflow_run(self): + if self.workflow_run_id: + return db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first() + + return None + class MessageFeedback(db.Model): __tablename__ = 'message_feedbacks' diff --git a/api/models/workflow.py b/api/models/workflow.py new file mode 100644 index 0000000000..59b8eeb6cd --- /dev/null +++ b/api/models/workflow.py @@ -0,0 +1,237 @@ +from sqlalchemy.dialects.postgresql import UUID + +from extensions.ext_database import db + + +class Workflow(db.Model): + """ + Workflow, for `Workflow App` and `Chat App 
workflow mode`. + + Attributes: + + - id (uuid) Workflow ID, pk + - tenant_id (uuid) Workspace ID + - app_id (uuid) App ID + - type (string) Workflow type + + `workflow` for `Workflow App` + + `chat` for `Chat App workflow mode` + + - version (string) Version + + `draft` for draft version (only one for each app), other for version number (redundant) + + - graph (text) Workflow canvas configuration (JSON) + + The entire canvas configuration JSON, including Node, Edge, and other configurations + + - nodes (array[object]) Node list, see Node Schema + + - edges (array[object]) Edge list, see Edge Schema + + - created_by (uuid) Creator ID + - created_at (timestamp) Creation time + - updated_by (uuid) `optional` Last updater ID + - updated_at (timestamp) `optional` Last update time + """ + + __tablename__ = 'workflows' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='workflow_pkey'), + db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'type', 'version'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + app_id = db.Column(UUID, nullable=False) + type = db.Column(db.String(255), nullable=False) + version = db.Column(db.String(255), nullable=False) + graph = db.Column(db.Text) + created_by = db.Column(UUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_by = db.Column(UUID) + updated_at = db.Column(db.DateTime) + + +class WorkflowRun(db.Model): + """ + Workflow Run + + Attributes: + + - id (uuid) Run ID + - tenant_id (uuid) Workspace ID + - app_id (uuid) App ID + - sequence_number (int) Auto-increment sequence number, incremented within the App, starting from 1 + - workflow_id (uuid) Workflow ID + - type (string) Workflow type + - triggered_from (string) Trigger source + + `debugging` for canvas debugging + + `app-run` for (published) app execution + + - version (string) Version + - graph (text) Workflow canvas configuration (JSON) + - inputs (text) Input parameters + - status (string) Execution status, `running` / `succeeded` / `failed` + - outputs (text) `optional` Output content + - error (string) `optional` Error reason + - elapsed_time (float) `optional` Time consumption (s) + - total_tokens (int) `optional` Total tokens used + - total_price (decimal) `optional` Total cost + - currency (string) `optional` Currency, such as USD / RMB + - total_steps (int) Total steps (redundant), default 0 + - created_by (uuid) Runner ID + - created_at (timestamp) Run time + - finished_at (timestamp) End time + """ + + __tablename__ = 'workflow_runs' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='workflow_run_pkey'), + db.Index('workflow_run_triggerd_from_idx', 'tenant_id', 'app_id', 'workflow_id', 'triggered_from'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + app_id = db.Column(UUID, nullable=False) + sequence_number = db.Column(db.Integer, nullable=False) + workflow_id = db.Column(UUID, nullable=False) + type = db.Column(db.String(255), nullable=False) + triggered_from = db.Column(db.String(255), nullable=False) + version = db.Column(db.String(255), nullable=False) + graph = db.Column(db.Text) + inputs = db.Column(db.Text) + status = db.Column(db.String(255), nullable=False) + outputs = db.Column(db.Text) + error = db.Column(db.Text) + elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0')) + total_tokens = db.Column(db.Integer, 
nullable=False, server_default=db.text('0')) + total_price = db.Column(db.Numeric(10, 7)) + currency = db.Column(db.String(255)) + total_steps = db.Column(db.Integer, server_default=db.text('0')) + created_by = db.Column(UUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + finished_at = db.Column(db.DateTime) + + +class WorkflowNodeExecution(db.Model): + """ + Workflow Node Execution + + - id (uuid) Execution ID + - tenant_id (uuid) Workspace ID + - app_id (uuid) App ID + - workflow_id (uuid) Workflow ID + - triggered_from (string) Trigger source + + `single-step` for single-step debugging + + `workflow-run` for workflow execution (debugging / user execution) + + - workflow_run_id (uuid) `optional` Workflow run ID + + Null for single-step debugging. + + - index (int) Execution sequence number, used for displaying Tracing Node order + - predecessor_node_id (string) `optional` Predecessor node ID, used for displaying execution path + - node_id (string) Node ID + - node_type (string) Node type, such as `start` + - title (string) Node title + - inputs (json) All predecessor node variable content used in the node + - process_data (json) Node process data + - outputs (json) `optional` Node output variables + - status (string) Execution status, `running` / `succeeded` / `failed` + - error (string) `optional` Error reason + - elapsed_time (float) `optional` Time consumption (s) + - execution_metadata (text) Metadata + + - total_tokens (int) `optional` Total tokens used + + - total_price (decimal) `optional` Total cost + + - currency (string) `optional` Currency, such as USD / RMB + + - created_at (timestamp) Run time + - created_by (uuid) Runner ID + - finished_at (timestamp) End time + """ + + __tablename__ = 'workflow_node_executions' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='workflow_node_execution_pkey'), + db.Index('workflow_node_execution_workflow_run_idx', 'tenant_id', 'app_id', 'workflow_id', + 'triggered_from', 'workflow_run_id'), + db.Index('workflow_node_execution_node_run_idx', 'tenant_id', 'app_id', 'workflow_id', + 'triggered_from', 'node_id'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + app_id = db.Column(UUID, nullable=False) + workflow_id = db.Column(UUID, nullable=False) + triggered_from = db.Column(db.String(255), nullable=False) + workflow_run_id = db.Column(UUID) + index = db.Column(db.Integer, nullable=False) + predecessor_node_id = db.Column(db.String(255)) + node_id = db.Column(db.String(255), nullable=False) + node_type = db.Column(db.String(255), nullable=False) + title = db.Column(db.String(255), nullable=False) + inputs = db.Column(db.Text, nullable=False) + process_data = db.Column(db.Text, nullable=False) + outputs = db.Column(db.Text) + status = db.Column(db.String(255), nullable=False) + error = db.Column(db.Text) + elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0')) + execution_metadata = db.Column(db.Text) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_by = db.Column(UUID, nullable=False) + finished_at = db.Column(db.DateTime) + + +class WorkflowAppLog(db.Model): + """ + Workflow App execution log, excluding workflow debugging records. 
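+    (Debugging runs are recorded only in workflow_runs, with triggered_from = 'debugging'.)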
+
+    Attributes:
+
+    - id (uuid) Log record ID
+    - tenant_id (uuid) Workspace ID
+    - app_id (uuid) App ID
+    - workflow_id (uuid) Associated Workflow ID
+    - workflow_run_id (uuid) Associated Workflow Run ID
+    - created_from (string) Creation source
+
+        `service-api` App Execution OpenAPI
+
+        `web-app` WebApp
+
+        `installed-app` Installed App
+
+    - created_by_role (string) Creator role
+
+        `account` Console account
+
+        `end_user` End user
+
+    - created_by (uuid) Creator ID; which table it references depends on created_by_role
+    - created_at (timestamp) Creation time
+    """
+
+    __tablename__ = 'workflow_app_logs'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='workflow_app_log_pkey'),
+        db.Index('workflow_app_log_app_idx', 'tenant_id', 'app_id'),
+    )
+
+    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(UUID, nullable=False)
+    app_id = db.Column(UUID, nullable=False)
+    workflow_id = db.Column(UUID, nullable=False)
+    workflow_run_id = db.Column(UUID, nullable=False)
+    created_from = db.Column(db.String(255), nullable=False)
+    created_by_role = db.Column(db.String(255), nullable=False)
+    created_by = db.Column(UUID, nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

From 603b1e9ed49c7b4b43033b43dcb76db1ebe5d476 Mon Sep 17 00:00:00 2001
From: takatost
Date: Mon, 19 Feb 2024 20:49:13 +0800
Subject: [PATCH 044/200] lint

---
 api/controllers/console/app/workflow.py              | 1 -
 api/migrations/versions/b289e2408ee2_add_workflow.py | 2 +-
 api/models/model.py                                  | 2 +-
 3 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index 4acdb4943d..5689c0fd92 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -1,7 +1,6 @@
 from flask_restful import Resource, reqparse

 from controllers.console import api
-from controllers.console.app.wraps import get_app_model
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.entities.application_entities import AppMode
diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py
index 52168a04e7..605c66bed1 100644
--- a/api/migrations/versions/b289e2408ee2_add_workflow.py
+++ b/api/migrations/versions/b289e2408ee2_add_workflow.py
@@ -5,8 +5,8 @@ Revises: 16830a790f0f
 Create Date: 2024-02-19 12:47:24.646954

 """
-from alembic import op
 import sqlalchemy as sa
+from alembic import op
 from sqlalchemy.dialects import postgresql

 # revision identifiers, used by Alembic.
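These models treat the literal version string 'draft' as the single mutable copy of a workflow per app, with published versions stored alongside it. A minimal sketch of the lookup that convention implies, assuming the session and model imports from the diffs above (the real helper lands in api/services/workflow_service.py later in this series):

    from typing import Optional

    from extensions.ext_database import db
    from models.workflow import Workflow


    def get_draft_workflow(app_model) -> Optional[Workflow]:
        # by convention at most one 'draft' row exists per app,
        # so first() is sufficient here
        return db.session.query(Workflow).filter(
            Workflow.tenant_id == app_model.tenant_id,
            Workflow.app_id == app_model.id,
            Workflow.version == 'draft'
        ).first()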
diff --git a/api/models/model.py b/api/models/model.py index 6e7a58ed45..2b44957b06 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -12,7 +12,7 @@ from extensions.ext_database import db from libs.helper import generate_string from .account import Account, Tenant -from .workflow import WorkflowRun, Workflow +from .workflow import Workflow, WorkflowRun class DifySetup(db.Model): From 3642dd3a7395a9c7b3a2ad3858bd89d6d089b772 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 20 Feb 2024 21:30:43 +0800 Subject: [PATCH 045/200] add workflow logics --- api/constants/model_template.py | 91 ++++-- api/controllers/console/__init__.py | 2 +- api/controllers/console/app/app.py | 50 ++-- api/controllers/console/app/audio.py | 2 +- api/controllers/console/app/completion.py | 3 +- api/controllers/console/app/conversation.py | 3 +- api/controllers/console/app/error.py | 6 + api/controllers/console/app/message.py | 50 +--- api/controllers/console/app/statistic.py | 2 +- api/controllers/console/app/workflow.py | 94 +++++-- api/controllers/console/app/wraps.py | 21 +- api/controllers/console/explore/message.py | 47 ---- api/controllers/console/ping.py | 17 ++ api/controllers/console/workspace/account.py | 15 +- api/controllers/console/workspace/members.py | 21 +- api/controllers/web/message.py | 47 ---- api/core/app_runner/basic_app_runner.py | 4 +- api/core/application_manager.py | 34 ++- api/core/entities/application_entities.py | 55 ++-- api/core/prompt/prompt_transform.py | 2 +- api/core/workflow/__init__.py | 0 api/core/workflow/entities/NodeEntities.py | 32 +++ api/core/workflow/entities/__init__.py | 0 api/core/workflow/nodes/__init__.py | 0 api/core/workflow/nodes/end/__init__.py | 0 api/core/workflow/nodes/end/end_node.py | 0 api/core/workflow/nodes/end/entities.py | 25 ++ api/core/workflow/workflow_engine_manager.py | 0 api/fields/annotation_fields.py | 8 +- api/fields/conversation_fields.py | 13 +- api/fields/member_fields.py | 38 +++ api/fields/workflow_fields.py | 16 ++ .../versions/b289e2408ee2_add_workflow.py | 2 +- api/models/model.py | 29 +- api/models/workflow.py | 55 +++- .../advanced_prompt_template_service.py | 2 +- api/services/app_model_config_service.py | 19 +- api/services/completion_service.py | 60 +--- api/services/errors/__init__.py | 2 +- api/services/errors/app.py | 2 - api/services/workflow/__init__.py | 0 api/services/workflow/defaults.py | 72 +++++ api/services/workflow/workflow_converter.py | 259 ++++++++++++++++++ api/services/workflow_service.py | 83 ++++++ 44 files changed, 894 insertions(+), 389 deletions(-) create mode 100644 api/controllers/console/ping.py create mode 100644 api/core/workflow/__init__.py create mode 100644 api/core/workflow/entities/NodeEntities.py create mode 100644 api/core/workflow/entities/__init__.py create mode 100644 api/core/workflow/nodes/__init__.py create mode 100644 api/core/workflow/nodes/end/__init__.py create mode 100644 api/core/workflow/nodes/end/end_node.py create mode 100644 api/core/workflow/nodes/end/entities.py create mode 100644 api/core/workflow/workflow_engine_manager.py create mode 100644 api/fields/member_fields.py create mode 100644 api/fields/workflow_fields.py delete mode 100644 api/services/errors/app.py create mode 100644 api/services/workflow/__init__.py create mode 100644 api/services/workflow/defaults.py create mode 100644 api/services/workflow/workflow_converter.py create mode 100644 api/services/workflow_service.py diff --git a/api/constants/model_template.py b/api/constants/model_template.py index 
d87f7c3926..c22306ac87 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -1,10 +1,10 @@ import json model_templates = { - # completion default mode - 'completion_default': { + # workflow default mode + 'workflow_default': { 'app': { - 'mode': 'completion', + 'mode': 'workflow', 'enable_site': True, 'enable_api': True, 'is_demo': False, @@ -15,24 +15,7 @@ model_templates = { 'model_config': { 'provider': '', 'model_id': '', - 'configs': {}, - 'model': json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo-instruct", - "mode": "completion", - "completion_params": {} - }), - 'user_input_form': json.dumps([ - { - "paragraph": { - "label": "Query", - "variable": "query", - "required": True, - "default": "" - } - } - ]), - 'pre_prompt': '{{query}}' + 'configs': {} } }, @@ -48,14 +31,70 @@ model_templates = { 'status': 'normal' }, 'model_config': { - 'provider': '', - 'model_id': '', - 'configs': {}, + 'provider': 'openai', + 'model_id': 'gpt-4', + 'configs': { + 'prompt_template': '', + 'prompt_variables': [], + 'completion_params': { + 'max_token': 512, + 'temperature': 1, + 'top_p': 1, + 'presence_penalty': 0, + 'frequency_penalty': 0, + } + }, 'model': json.dumps({ "provider": "openai", - "name": "gpt-3.5-turbo", + "name": "gpt-4", "mode": "chat", - "completion_params": {} + "completion_params": { + "max_tokens": 512, + "temperature": 1, + "top_p": 1, + "presence_penalty": 0, + "frequency_penalty": 0 + } + }) + } + }, + + # agent default mode + 'agent_default': { + 'app': { + 'mode': 'agent', + 'enable_site': True, + 'enable_api': True, + 'is_demo': False, + 'api_rpm': 0, + 'api_rph': 0, + 'status': 'normal' + }, + 'model_config': { + 'provider': 'openai', + 'model_id': 'gpt-4', + 'configs': { + 'prompt_template': '', + 'prompt_variables': [], + 'completion_params': { + 'max_token': 512, + 'temperature': 1, + 'top_p': 1, + 'presence_penalty': 0, + 'frequency_penalty': 0, + } + }, + 'model': json.dumps({ + "provider": "openai", + "name": "gpt-4", + "mode": "chat", + "completion_params": { + "max_tokens": 512, + "temperature": 1, + "top_p": 1, + "presence_penalty": 0, + "frequency_penalty": 0 + } }) } }, diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 934b19116b..649df278ec 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -5,7 +5,7 @@ bp = Blueprint('console', __name__, url_prefix='/console/api') api = ExternalApi(bp) # Import other controllers -from . import admin, apikey, extension, feature, setup, version +from . 
import admin, apikey, extension, feature, setup, version, ping # Import app controllers from .app import (advanced_prompt_template, annotation, app, audio, completion, conversation, generator, message, model_config, site, statistic, workflow) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index c366ace93a..cf505bedb8 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -26,7 +26,7 @@ from fields.app_fields import ( template_list_fields, ) from libs.login import login_required -from models.model import App, AppModelConfig, Site +from models.model import App, AppModelConfig, Site, AppMode from services.app_model_config_service import AppModelConfigService from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.tool_manager import ToolManager @@ -80,7 +80,7 @@ class AppListApi(Resource): """Create app""" parser = reqparse.RequestParser() parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('mode', type=str, choices=['completion', 'chat', 'assistant'], location='json') + parser.add_argument('mode', type=str, choices=[mode.value for mode in AppMode], location='json') parser.add_argument('icon', type=str, location='json') parser.add_argument('icon_background', type=str, location='json') parser.add_argument('model_config', type=dict, location='json') @@ -90,18 +90,7 @@ class AppListApi(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - try: - provider_manager = ProviderManager() - default_model_entity = provider_manager.get_default_model( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.LLM - ) - except (ProviderTokenNotInitError, LLMBadRequestError): - default_model_entity = None - except Exception as e: - logging.exception(e) - default_model_entity = None - + # TODO: MOVE TO IMPORT API if args['model_config'] is not None: # validate config model_config_dict = args['model_config'] @@ -150,27 +139,30 @@ class AppListApi(Resource): if 'mode' not in args or args['mode'] is None: abort(400, message="mode is required") - model_config_template = model_templates[args['mode'] + '_default'] + app_mode = AppMode.value_of(args['mode']) + + model_config_template = model_templates[app_mode.value + '_default'] app = App(**model_config_template['app']) app_model_config = AppModelConfig(**model_config_template['model_config']) - # get model provider - model_manager = ModelManager() + if app_mode in [AppMode.CHAT, AppMode.AGENT]: + # get model provider + model_manager = ModelManager() - try: - model_instance = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.LLM - ) - except ProviderTokenNotInitError: - model_instance = None + try: + model_instance = model_manager.get_default_model_instance( + tenant_id=current_user.current_tenant_id, + model_type=ModelType.LLM + ) + except ProviderTokenNotInitError: + model_instance = None - if model_instance: - model_dict = app_model_config.model_dict - model_dict['provider'] = model_instance.provider - model_dict['name'] = model_instance.model - app_model_config.model = json.dumps(model_dict) + if model_instance: + model_dict = app_model_config.model_dict + model_dict['provider'] = model_instance.provider + model_dict['name'] = model_instance.model + app_model_config.model = json.dumps(model_dict) app.name = args['name'] app.mode = args['mode'] diff --git a/api/controllers/console/app/audio.py 
b/api/controllers/console/app/audio.py index daa5570f9a..458fa5098f 100644
--- a/api/controllers/console/app/audio.py
+++ b/api/controllers/console/app/audio.py
@@ -20,10 +20,10 @@ from controllers.console.app.error import (
 from controllers.console.app.wraps import get_app_model
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.entities.application_entities import AppMode
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
 from libs.login import login_required
+from models.model import AppMode
 from services.audio_service import AudioService
 from services.errors.audio import (
     AudioTooLargeServiceError,
diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py
index 381d0bbb6b..11fdba177d 100644
--- a/api/controllers/console/app/completion.py
+++ b/api/controllers/console/app/completion.py
@@ -22,11 +22,12 @@ from controllers.console.app.wraps import get_app_model
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.application_queue_manager import ApplicationQueueManager
-from core.entities.application_entities import AppMode, InvokeFrom
+from core.entities.application_entities import InvokeFrom
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
 from libs.helper import uuid_value
 from libs.login import login_required
+from models.model import AppMode
 from services.completion_service import CompletionService
diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py
index 4ee1ee4035..5d312149f7 100644
--- a/api/controllers/console/app/conversation.py
+++ b/api/controllers/console/app/conversation.py
@@ -12,7 +12,6 @@ from controllers.console import api
 from controllers.console.app.wraps import get_app_model
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.entities.application_entities import AppMode
 from extensions.ext_database import db
 from fields.conversation_fields import (
     conversation_detail_fields,
@@ -22,7 +21,7 @@ from fields.conversation_fields import (
 )
 from libs.helper import datetime_string
 from libs.login import login_required
-from models.model import Conversation, Message, MessageAnnotation
+from models.model import Conversation, Message, MessageAnnotation, AppMode


 class CompletionConversationApi(Resource):
diff --git a/api/controllers/console/app/error.py b/api/controllers/console/app/error.py
index d7b31906c8..b1abb38248 100644
--- a/api/controllers/console/app/error.py
+++ b/api/controllers/console/app/error.py
@@ -85,3 +85,9 @@ class TooManyFilesError(BaseHTTPException):
     error_code = 'too_many_files'
     description = "Only one file is allowed."
     code = 400
+
+
+class DraftWorkflowNotExist(BaseHTTPException):
+    error_code = 'draft_workflow_not_exist'
+    description = "Draft workflow needs to be initialized."
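+    # raised by DraftWorkflowApi.get when an app has no 'draft' workflow row yet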
+ code = 400 diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 5d4f6b7e26..9a177116ea 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -11,7 +11,6 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from controllers.console import api from controllers.console.app.error import ( - AppMoreLikeThisDisabledError, CompletionRequestError, ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, @@ -20,7 +19,6 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check -from core.entities.application_entities import AppMode, InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db @@ -28,10 +26,8 @@ from fields.conversation_fields import annotation_fields, message_detail_fields from libs.helper import uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.login import login_required -from models.model import Conversation, Message, MessageAnnotation, MessageFeedback +from models.model import Conversation, Message, MessageAnnotation, MessageFeedback, AppMode from services.annotation_service import AppAnnotationService -from services.completion_service import CompletionService -from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError from services.message_service import MessageService @@ -183,49 +179,6 @@ class MessageAnnotationCountApi(Resource): return {'count': count} -class MessageMoreLikeThisApi(Resource): - @setup_required - @login_required - @account_initialization_required - @get_app_model(mode=AppMode.COMPLETION) - def get(self, app_model, message_id): - message_id = str(message_id) - - parser = reqparse.RequestParser() - parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], - location='args') - args = parser.parse_args() - - streaming = args['response_mode'] == 'streaming' - - try: - response = CompletionService.generate_more_like_this( - app_model=app_model, - user=current_user, - message_id=message_id, - invoke_from=InvokeFrom.DEBUGGER, - streaming=streaming - ) - return compact_response(response) - except MessageNotExistsError: - raise NotFound("Message Not Exists.") - except MoreLikeThisDisabledError: - raise AppMoreLikeThisDisabledError() - except ProviderTokenNotInitError as ex: - raise ProviderNotInitializeError(ex.description) - except QuotaExceededError: - raise ProviderQuotaExceededError() - except ModelCurrentlyNotSupportError: - raise ProviderModelCurrentlyNotSupportError() - except InvokeError as e: - raise CompletionRequestError(e.description) - except ValueError as e: - raise e - except Exception as e: - logging.exception("internal server error.") - raise InternalServerError() - - def compact_response(response: Union[dict, Generator]) -> Response: if isinstance(response, dict): return Response(response=json.dumps(response), status=200, mimetype='application/json') @@ -291,7 +244,6 @@ class MessageApi(Resource): return message -api.add_resource(MessageMoreLikeThisApi, 
'/apps//completion-messages//more-like-this') api.add_resource(MessageSuggestedQuestionApi, '/apps//chat-messages//suggested-questions') api.add_resource(ChatMessageListApi, '/apps//chat-messages', endpoint='console_chat_messages') api.add_resource(MessageFeedbackApi, '/apps//feedbacks') diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index e3bc44d6e9..ea4d597112 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -10,10 +10,10 @@ from controllers.console import api from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.entities.application_entities import AppMode from extensions.ext_database import db from libs.helper import datetime_string from libs.login import login_required +from models.model import AppMode class DailyConversationStatistic(Resource): diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 5689c0fd92..2794735bbb 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -1,30 +1,88 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse, marshal_with from controllers.console import api +from controllers.console.app.error import DraftWorkflowNotExist +from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.entities.application_entities import AppMode -from libs.login import login_required +from fields.workflow_fields import workflow_fields +from libs.login import login_required, current_user +from models.model import App, ChatbotAppEngine, AppMode +from services.workflow_service import WorkflowService + + +class DraftWorkflowApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.CHAT, AppMode.WORKFLOW], app_engine=ChatbotAppEngine.WORKFLOW) + @marshal_with(workflow_fields) + def get(self, app_model: App): + """ + Get draft workflow + """ + # fetch draft workflow by app_model + workflow_service = WorkflowService() + workflow = workflow_service.get_draft_workflow(app_model=app_model) + + if not workflow: + raise DraftWorkflowNotExist() + + # return workflow, if not found, return None (initiate graph by frontend) + return workflow + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.CHAT, AppMode.WORKFLOW], app_engine=ChatbotAppEngine.WORKFLOW) + def post(self, app_model: App): + """ + Sync draft workflow + """ + parser = reqparse.RequestParser() + parser.add_argument('graph', type=dict, required=True, nullable=False, location='json') + args = parser.parse_args() + + workflow_service = WorkflowService() + workflow_service.sync_draft_workflow(app_model=app_model, graph=args.get('graph'), account=current_user) + + return { + "result": "success" + } class DefaultBlockConfigApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - parser = reqparse.RequestParser() - parser.add_argument('app_mode', type=str, required=True, nullable=False, - choices=[AppMode.CHAT.value, AppMode.WORKFLOW.value], location='args') - args = parser.parse_args() - - app_mode = args.get('app_mode') - app_mode = AppMode.value_of(app_mode) - - # TODO: implement this 
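The draft endpoints above keep the whole canvas as a single `graph` dict, so syncing from a client is one POST and reading it back is one GET. A hedged sketch of that round trip, assuming a local API on port 5001, a valid console bearer token, and an app id for the elided route parameter (all three are placeholders, not values from the diffs; the `/console/api` prefix comes from the blueprint registration above):

    import requests

    BASE = 'http://localhost:5001/console/api'
    app_id = '...'  # a chat/workflow app using the workflow engine
    headers = {'Authorization': 'Bearer <console-access-token>'}

    # push the canvas; the shape follows the Workflow.graph docstring
    graph = {'nodes': [], 'edges': []}
    resp = requests.post(f'{BASE}/apps/{app_id}/workflows/draft',
                         json={'graph': graph}, headers=headers)
    assert resp.json() == {'result': 'success'}

    # read it back; a 400 with error_code 'draft_workflow_not_exist'
    # means no draft has been synced yet
    draft = requests.get(f'{BASE}/apps/{app_id}/workflows/draft', headers=headers).json()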
- - return { - "blocks": [] - } + @get_app_model(mode=[AppMode.CHAT, AppMode.WORKFLOW], app_engine=ChatbotAppEngine.WORKFLOW) + def get(self, app_model: App): + """ + Get default block config + """ + # Get default block configs + workflow_service = WorkflowService() + return workflow_service.get_default_block_configs() -api.add_resource(DefaultBlockConfigApi, '/default-workflow-block-configs') +class ConvertToWorkflowApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=AppMode.CHAT) + @marshal_with(workflow_fields) + def post(self, app_model: App): + """ + Convert basic mode of chatbot app to workflow + """ + # convert to workflow mode + workflow_service = WorkflowService() + workflow = workflow_service.chatbot_convert_to_workflow(app_model=app_model) + + # return workflow + return workflow + + +api.add_resource(DraftWorkflowApi, '/apps//workflows/draft') +api.add_resource(DefaultBlockConfigApi, '/apps//workflows/default-workflow-block-configs') +api.add_resource(ConvertToWorkflowApi, '/apps//convert-to-workflow') diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index fe2b408702..fe35e72304 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -3,13 +3,14 @@ from functools import wraps from typing import Optional, Union from controllers.console.app.error import AppNotFoundError -from core.entities.application_entities import AppMode from extensions.ext_database import db from libs.login import current_user -from models.model import App +from models.model import App, ChatbotAppEngine, AppMode -def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None): +def get_app_model(view: Optional[Callable] = None, *, + mode: Union[AppMode, list[AppMode]] = None, + app_engine: ChatbotAppEngine = None): def decorator(view_func): @wraps(view_func) def decorated_view(*args, **kwargs): @@ -37,14 +38,20 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[ else: modes = [mode] - # [temp] if workflow is in the mode list, then completion should be in the mode list - if AppMode.WORKFLOW in modes: - modes.append(AppMode.COMPLETION) - if app_mode not in modes: mode_values = {m.value for m in modes} raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}") + if app_engine is not None: + if app_mode not in [AppMode.CHAT, AppMode.WORKFLOW]: + raise AppNotFoundError(f"App mode is not supported for {app_engine.value} app engine.") + + if app_mode == AppMode.CHAT: + # fetch current app model config + app_model_config = app_model.app_model_config + if not app_model_config or app_model_config.chatbot_app_engine != app_engine.value: + raise AppNotFoundError(f"{app_engine.value} app engine is not supported.") + kwargs['app_model'] = app_model return view_func(*args, **kwargs) diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 47af28425f..bef26b4d99 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -12,7 +12,6 @@ from werkzeug.exceptions import InternalServerError, NotFound import services from controllers.console import api from controllers.console.app.error import ( - AppMoreLikeThisDisabledError, CompletionRequestError, ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, @@ -24,13 +23,10 @@ from controllers.console.explore.error import ( NotCompletionAppError, ) from 
controllers.console.explore.wraps import InstalledAppResource -from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from fields.message_fields import message_infinite_scroll_pagination_fields from libs.helper import uuid_value -from services.completion_service import CompletionService -from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError from services.message_service import MessageService @@ -76,48 +72,6 @@ class MessageFeedbackApi(InstalledAppResource): return {'result': 'success'} -class MessageMoreLikeThisApi(InstalledAppResource): - def get(self, installed_app, message_id): - app_model = installed_app.app - if app_model.mode != 'completion': - raise NotCompletionAppError() - - message_id = str(message_id) - - parser = reqparse.RequestParser() - parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args') - args = parser.parse_args() - - streaming = args['response_mode'] == 'streaming' - - try: - response = CompletionService.generate_more_like_this( - app_model=app_model, - user=current_user, - message_id=message_id, - invoke_from=InvokeFrom.EXPLORE, - streaming=streaming - ) - return compact_response(response) - except MessageNotExistsError: - raise NotFound("Message Not Exists.") - except MoreLikeThisDisabledError: - raise AppMoreLikeThisDisabledError() - except ProviderTokenNotInitError as ex: - raise ProviderNotInitializeError(ex.description) - except QuotaExceededError: - raise ProviderQuotaExceededError() - except ModelCurrentlyNotSupportError: - raise ProviderModelCurrentlyNotSupportError() - except InvokeError as e: - raise CompletionRequestError(e.description) - except ValueError as e: - raise e - except Exception: - logging.exception("internal server error.") - raise InternalServerError() - - def compact_response(response: Union[dict, Generator]) -> Response: if isinstance(response, dict): return Response(response=json.dumps(response), status=200, mimetype='application/json') @@ -166,5 +120,4 @@ class MessageSuggestedQuestionApi(InstalledAppResource): api.add_resource(MessageListApi, '/installed-apps//messages', endpoint='installed_app_messages') api.add_resource(MessageFeedbackApi, '/installed-apps//messages//feedbacks', endpoint='installed_app_message_feedback') -api.add_resource(MessageMoreLikeThisApi, '/installed-apps//messages//more-like-this', endpoint='installed_app_more_like_this') api.add_resource(MessageSuggestedQuestionApi, '/installed-apps//messages//suggested-questions', endpoint='installed_app_suggested_question') diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py new file mode 100644 index 0000000000..7664ba8c16 --- /dev/null +++ b/api/controllers/console/ping.py @@ -0,0 +1,17 @@ +from flask_restful import Resource + +from controllers.console import api + + +class PingApi(Resource): + + def get(self): + """ + For connection health check + """ + return { + "result": "pong" + } + + +api.add_resource(PingApi, '/ping') diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index b7cfba9d04..656a4d4cee 100644 --- a/api/controllers/console/workspace/account.py +++ 
b/api/controllers/console/workspace/account.py @@ -16,26 +16,13 @@ from controllers.console.workspace.error import ( ) from controllers.console.wraps import account_initialization_required from extensions.ext_database import db +from fields.member_fields import account_fields from libs.helper import TimestampField, timezone from libs.login import login_required from models.account import AccountIntegrate, InvitationCode from services.account_service import AccountService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError -account_fields = { - 'id': fields.String, - 'name': fields.String, - 'avatar': fields.String, - 'email': fields.String, - 'is_password_set': fields.Boolean, - 'interface_language': fields.String, - 'interface_theme': fields.String, - 'timezone': fields.String, - 'last_login_at': TimestampField, - 'last_login_ip': fields.String, - 'created_at': TimestampField -} - class AccountInitApi(Resource): diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index cf57cd4b24..f40ccebf25 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,33 +1,18 @@ from flask import current_app from flask_login import current_user -from flask_restful import Resource, abort, fields, marshal_with, reqparse +from flask_restful import Resource, abort, marshal_with, reqparse import services from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from extensions.ext_database import db -from libs.helper import TimestampField +from fields.member_fields import account_with_role_list_fields from libs.login import login_required from models.account import Account from services.account_service import RegisterService, TenantService from services.errors.account import AccountAlreadyInTenantError -account_fields = { - 'id': fields.String, - 'name': fields.String, - 'avatar': fields.String, - 'email': fields.String, - 'last_login_at': TimestampField, - 'created_at': TimestampField, - 'role': fields.String, - 'status': fields.String, -} - -account_list_fields = { - 'accounts': fields.List(fields.Nested(account_fields)) -} - class MemberListApi(Resource): """List all members of current tenant.""" @@ -35,7 +20,7 @@ class MemberListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_list_fields) + @marshal_with(account_with_role_list_fields) def get(self): members = TenantService.get_tenant_members(current_user.current_tenant) return {'result': 'success', 'accounts': members}, 200 diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index e03bdd63bb..5120f49c5e 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -11,7 +11,6 @@ from werkzeug.exceptions import InternalServerError, NotFound import services from controllers.web import api from controllers.web.error import ( - AppMoreLikeThisDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError, CompletionRequestError, NotChatAppError, @@ -21,14 +20,11 @@ from controllers.web.error import ( ProviderQuotaExceededError, ) from controllers.web.wraps import WebApiResource -from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from 
core.model_runtime.errors.invoke import InvokeError
 from fields.conversation_fields import message_file_fields
 from fields.message_fields import agent_thought_fields
 from libs.helper import TimestampField, uuid_value
-from services.completion_service import CompletionService
-from services.errors.app import MoreLikeThisDisabledError
 from services.errors.conversation import ConversationNotExistsError
 from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
 from services.message_service import MessageService
@@ -113,48 +109,6 @@ class MessageFeedbackApi(WebApiResource):
         return {'result': 'success'}
 
 
-class MessageMoreLikeThisApi(WebApiResource):
-    def get(self, app_model, end_user, message_id):
-        if app_model.mode != 'completion':
-            raise NotCompletionAppError()
-
-        message_id = str(message_id)
-
-        parser = reqparse.RequestParser()
-        parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
-        args = parser.parse_args()
-
-        streaming = args['response_mode'] == 'streaming'
-
-        try:
-            response = CompletionService.generate_more_like_this(
-                app_model=app_model,
-                user=end_user,
-                message_id=message_id,
-                invoke_from=InvokeFrom.WEB_APP,
-                streaming=streaming
-            )
-
-            return compact_response(response)
-        except MessageNotExistsError:
-            raise NotFound("Message Not Exists.")
-        except MoreLikeThisDisabledError:
-            raise AppMoreLikeThisDisabledError()
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
-        except QuotaExceededError:
-            raise ProviderQuotaExceededError()
-        except ModelCurrentlyNotSupportError:
-            raise ProviderModelCurrentlyNotSupportError()
-        except InvokeError as e:
-            raise CompletionRequestError(e.description)
-        except ValueError as e:
-            raise e
-        except Exception:
-            logging.exception("internal server error.")
-            raise InternalServerError()
-
-
 def compact_response(response: Union[dict, Generator]) -> Response:
     if isinstance(response, dict):
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
@@ -202,5 +156,4 @@ class MessageSuggestedQuestionApi(WebApiResource):
 
 api.add_resource(MessageListApi, '/messages')
 api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks')
-api.add_resource(MessageMoreLikeThisApi, '/messages/<uuid:message_id>/more-like-this')
 api.add_resource(MessageSuggestedQuestionApi, '/messages/<uuid:message_id>/suggested-questions')
diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app_runner/basic_app_runner.py
index d87302c717..26e9cc84aa 100644
--- a/api/core/app_runner/basic_app_runner.py
+++ b/api/core/app_runner/basic_app_runner.py
@@ -6,7 +6,6 @@ from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.application_entities import (
     ApplicationGenerateEntity,
-    AppMode,
     DatasetEntity,
     InvokeFrom,
     ModelConfigEntity,
@@ -16,7 +15,7 @@ from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.moderation.base import ModerationException
 from extensions.ext_database import db
-from models.model import App, Conversation, Message
+from models.model import App, Conversation, Message, AppMode
 
 logger = logging.getLogger(__name__)
 
@@ -250,6 +249,7 @@ class BasicApplicationRunner(AppRunner):
                 invoke_from
             )
 
+        # TODO
         if (app_record.mode == AppMode.COMPLETION.value
                 and dataset_config
                 and dataset_config.retrieve_config.query_variable):
            query 
= inputs.get(dataset_config.retrieve_config.query_variable, "") diff --git a/api/core/application_manager.py b/api/core/application_manager.py index 9aca61c7bb..2fde422d47 100644 --- a/api/core/application_manager.py +++ b/api/core/application_manager.py @@ -28,7 +28,7 @@ from core.entities.application_entities import ( ModelConfigEntity, PromptTemplateEntity, SensitiveWordAvoidanceEntity, - TextToSpeechEntity, + TextToSpeechEntity, VariableEntity, ) from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError @@ -93,7 +93,7 @@ class ApplicationManager: app_id=app_id, app_model_config_id=app_model_config_id, app_model_config_dict=app_model_config_dict, - app_orchestration_config_entity=self._convert_from_app_model_config_dict( + app_orchestration_config_entity=self.convert_from_app_model_config_dict( tenant_id=tenant_id, app_model_config_dict=app_model_config_dict ), @@ -234,7 +234,7 @@ class ApplicationManager: logger.exception(e) raise e - def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \ + def convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \ -> AppOrchestrationConfigEntity: """ Convert app model config dict to entity. @@ -384,8 +384,10 @@ class ApplicationManager: config=external_data_tool['config'] ) ) + + properties['variables'] = [] - # current external_data_tools + # variables and external_data_tools for variable in copy_app_model_config_dict.get('user_input_form', []): typ = list(variable.keys())[0] if typ == 'external_data_tool': @@ -397,6 +399,30 @@ class ApplicationManager: config=val['config'] ) ) + elif typ in [VariableEntity.Type.TEXT_INPUT.value, VariableEntity.Type.PARAGRAPH.value]: + properties['variables'].append( + VariableEntity( + type=VariableEntity.Type.TEXT_INPUT, + variable=variable[typ].get('variable'), + description=variable[typ].get('description'), + label=variable[typ].get('label'), + required=variable[typ].get('required', False), + max_length=variable[typ].get('max_length'), + default=variable[typ].get('default'), + ) + ) + elif typ == VariableEntity.Type.SELECT.value: + properties['variables'].append( + VariableEntity( + type=VariableEntity.Type.SELECT, + variable=variable[typ].get('variable'), + description=variable[typ].get('description'), + label=variable[typ].get('label'), + required=variable[typ].get('required', False), + options=variable[typ].get('options'), + default=variable[typ].get('default'), + ) + ) # show retrieve source show_retrieve_source = False diff --git a/api/core/entities/application_entities.py b/api/core/entities/application_entities.py index d3231affb2..092591a73f 100644 --- a/api/core/entities/application_entities.py +++ b/api/core/entities/application_entities.py @@ -9,26 +9,6 @@ from core.model_runtime.entities.message_entities import PromptMessageRole from core.model_runtime.entities.model_entities import AIModelEntity -class AppMode(Enum): - COMPLETION = 'completion' # will be deprecated in the future - WORKFLOW = 'workflow' # instead of 'completion' - CHAT = 'chat' - AGENT = 'agent' - - @classmethod - def value_of(cls, value: str) -> 'AppMode': - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f'invalid mode value {value}') - - class ModelConfigEntity(BaseModel): """ Model Config Entity. 
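The `user_input_form` conversion above maps each form control in the app model config onto a `VariableEntity` (added to `application_entities.py` in the next hunk). A small illustration; the variable name, label, and limits here are invented for the example:

```python
from core.entities.application_entities import VariableEntity

# One `text-input` control as it would appear in
# app_model_config['user_input_form'] (sample values only).
entry = {
    'text-input': {
        'variable': 'customer_name',
        'label': 'Customer name',
        'required': True,
        'max_length': 48,
    }
}

typ = list(entry.keys())[0]
val = entry[typ]

# Mirrors the TEXT_INPUT / PARAGRAPH branch in the converter above.
variable = VariableEntity(
    type=VariableEntity.Type.TEXT_INPUT,
    variable=val.get('variable'),
    description=val.get('description'),
    label=val.get('label'),
    required=val.get('required', False),
    max_length=val.get('max_length'),
    default=val.get('default'),
)
assert variable.required and variable.max_length == 48
```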
@@ -106,6 +86,38 @@ class PromptTemplateEntity(BaseModel):
     advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
 
 
+class VariableEntity(BaseModel):
+    """
+    Variable Entity.
+    """
+    class Type(Enum):
+        TEXT_INPUT = 'text-input'
+        SELECT = 'select'
+        PARAGRAPH = 'paragraph'
+
+        @classmethod
+        def value_of(cls, value: str) -> 'VariableEntity.Type':
+            """
+            Get value of given variable type.
+
+            :param value: variable type value
+            :return: variable type
+            """
+            for variable_type in cls:
+                if variable_type.value == value:
+                    return variable_type
+            raise ValueError(f'invalid variable type value {value}')
+
+    variable: str
+    label: str
+    description: Optional[str] = None
+    type: Type
+    required: bool = False
+    max_length: Optional[int] = None
+    options: Optional[list[str]] = None
+    default: Optional[str] = None
+
+
 class ExternalDataVariableEntity(BaseModel):
     """
     External Data Variable Entity.
@@ -245,6 +257,7 @@ class AppOrchestrationConfigEntity(BaseModel):
     """
     model_config: ModelConfigEntity
     prompt_template: PromptTemplateEntity
+    variables: list[VariableEntity] = []
     external_data_variables: list[ExternalDataVariableEntity] = []
     agent: Optional[AgentEntity] = None
 
@@ -256,7 +269,7 @@ class AppOrchestrationConfigEntity(BaseModel):
     show_retrieve_source: bool = False
     more_like_this: bool = False
     speech_to_text: bool = False
-    text_to_speech: dict = {}
+    text_to_speech: Optional[TextToSpeechEntity] = None
 
     sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
 
diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py
index 4bf96ce265..abbfa96249 100644
--- a/api/core/prompt/prompt_transform.py
+++ b/api/core/prompt/prompt_transform.py
@@ -6,7 +6,6 @@ from typing import Optional, cast
 
 from core.entities.application_entities import (
     AdvancedCompletionPromptTemplateEntity,
-    AppMode,
     ModelConfigEntity,
     PromptTemplateEntity,
 )
@@ -24,6 +23,7 @@ from core.model_runtime.entities.model_entities import ModelPropertyKey
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.prompt.prompt_builder import PromptBuilder
 from core.prompt.prompt_template import PromptTemplateParser
+from models.model import AppMode
 
 
 class ModelMode(enum.Enum):
diff --git a/api/core/workflow/__init__.py b/api/core/workflow/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/workflow/entities/NodeEntities.py b/api/core/workflow/entities/NodeEntities.py
new file mode 100644
index 0000000000..d72b000dfb
--- /dev/null
+++ b/api/core/workflow/entities/NodeEntities.py
@@ -0,0 +1,32 @@
+from enum import Enum
+
+
+class NodeType(Enum):
+    """
+    Node Types.
+    """
+    START = 'start'
+    END = 'end'
+    DIRECT_ANSWER = 'direct-answer'
+    LLM = 'llm'
+    KNOWLEDGE_RETRIEVAL = 'knowledge-retrieval'
+    IF_ELSE = 'if-else'
+    CODE = 'code'
+    TEMPLATE_TRANSFORM = 'template-transform'
+    QUESTION_CLASSIFIER = 'question-classifier'
+    HTTP_REQUEST = 'http-request'
+    TOOL = 'tool'
+    VARIABLE_ASSIGNER = 'variable-assigner'
+
+    @classmethod
+    def value_of(cls, value: str) -> 'NodeType':
+        """
+        Get value of given node type.
+
+        :param value: node type value
+        :return: node type
+        """
+        for node_type in cls:
+            if node_type.value == value:
+                return node_type
+        raise ValueError(f'invalid node type value {value}')
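`NodeType.value_of` is the usual string-to-enum round-trip, raising on unknown node types:

```python
from core.workflow.entities.NodeEntities import NodeType

assert NodeType.value_of('llm') is NodeType.LLM
assert NodeType.value_of('knowledge-retrieval') is NodeType.KNOWLEDGE_RETRIEVAL

try:
    NodeType.value_of('no-such-node')
except ValueError as e:
    print(e)  # invalid node type value no-such-node
```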
diff --git a/api/core/workflow/entities/__init__.py b/api/core/workflow/entities/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/workflow/nodes/__init__.py b/api/core/workflow/nodes/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/workflow/nodes/end/__init__.py b/api/core/workflow/nodes/end/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/workflow/nodes/end/entities.py b/api/core/workflow/nodes/end/entities.py
new file mode 100644
index 0000000000..045e7effc4
--- /dev/null
+++ b/api/core/workflow/nodes/end/entities.py
@@ -0,0 +1,25 @@
+from enum import Enum
+
+
+class EndNodeOutputType(Enum):
+    """
+    END Node Output Types.
+
+    none, plain-text, structured
+    """
+    NONE = 'none'
+    PLAIN_TEXT = 'plain-text'
+    STRUCTURED = 'structured'
+
+    @classmethod
+    def value_of(cls, value: str) -> 'EndNodeOutputType':
+        """
+        Get value of given output type.
+
+        :param value: output type value
+        :return: output type
+        """
+        for output_type in cls:
+            if output_type.value == value:
+                return output_type
+        raise ValueError(f'invalid output type value {value}')
diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py
index 5974de34de..d9cd6c03bb 100644
--- a/api/fields/annotation_fields.py
+++ b/api/fields/annotation_fields.py
@@ -2,12 +2,6 @@ from flask_restful import fields
 
 from libs.helper import TimestampField
 
-account_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'email': fields.String
-}
-
 annotation_fields = {
     "id": fields.String,
@@ -15,7 +9,7 @@ annotation_fields = {
     "answer": fields.Raw(attribute='content'),
     "hit_count": fields.Integer,
     "created_at": TimestampField,
-    # 'account': fields.Nested(account_fields, allow_null=True)
+    # 'account': fields.Nested(simple_account_fields, allow_null=True)
 }
 
 annotation_list_fields = {
diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py
index 1adc836aa2..afa486f1cd 100644
--- a/api/fields/conversation_fields.py
+++ b/api/fields/conversation_fields.py
@@ -1,5 +1,6 @@
 from flask_restful import fields
 
+from fields.member_fields import simple_account_fields
 from libs.helper import TimestampField
 
 
@@ -8,31 +9,25 @@ class MessageTextField(fields.Raw):
         return value[0]['text'] if value else ''
 
 
-account_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'email': fields.String
-}
-
 feedback_fields = {
     'rating': fields.String,
     'content': fields.String,
     'from_source': fields.String,
     'from_end_user_id': fields.String,
-    'from_account': fields.Nested(account_fields, allow_null=True),
+    'from_account': fields.Nested(simple_account_fields, allow_null=True),
 }
 
 annotation_fields = {
     'id': fields.String,
     'question': fields.String,
     'content': fields.String,
-    'account': fields.Nested(account_fields, allow_null=True),
+    'account': fields.Nested(simple_account_fields, allow_null=True),
     'created_at': TimestampField
 }
 
 annotation_hit_history_fields = {
     'annotation_id': fields.String(attribute='id'),
-    
'annotation_create_account': fields.Nested(account_fields, allow_null=True), + 'annotation_create_account': fields.Nested(simple_account_fields, allow_null=True), 'created_at': TimestampField } diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py new file mode 100644 index 0000000000..79164b3848 --- /dev/null +++ b/api/fields/member_fields.py @@ -0,0 +1,38 @@ +from flask_restful import fields + +from libs.helper import TimestampField + +simple_account_fields = { + 'id': fields.String, + 'name': fields.String, + 'email': fields.String +} + +account_fields = { + 'id': fields.String, + 'name': fields.String, + 'avatar': fields.String, + 'email': fields.String, + 'is_password_set': fields.Boolean, + 'interface_language': fields.String, + 'interface_theme': fields.String, + 'timezone': fields.String, + 'last_login_at': TimestampField, + 'last_login_ip': fields.String, + 'created_at': TimestampField +} + +account_with_role_fields = { + 'id': fields.String, + 'name': fields.String, + 'avatar': fields.String, + 'email': fields.String, + 'last_login_at': TimestampField, + 'created_at': TimestampField, + 'role': fields.String, + 'status': fields.String, +} + +account_with_role_list_fields = { + 'accounts': fields.List(fields.Nested(account_with_role_fields)) +} diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py new file mode 100644 index 0000000000..9dc92ea43b --- /dev/null +++ b/api/fields/workflow_fields.py @@ -0,0 +1,16 @@ +import json + +from flask_restful import fields + +from fields.member_fields import simple_account_fields +from libs.helper import TimestampField + + +workflow_fields = { + 'id': fields.String, + 'graph': fields.Raw(attribute=lambda x: json.loads(x.graph) if hasattr(x, 'graph') else None), + 'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'), + 'created_at': TimestampField, + 'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), + 'updated_at': TimestampField +} diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py index 605c66bed1..e9cd2caf3a 100644 --- a/api/migrations/versions/b289e2408ee2_add_workflow.py +++ b/api/migrations/versions/b289e2408ee2_add_workflow.py @@ -102,7 +102,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='workflow_pkey') ) with op.batch_alter_table('workflows', schema=None) as batch_op: - batch_op.create_index('workflow_version_idx', ['tenant_id', 'app_id', 'type', 'version'], unique=False) + batch_op.create_index('workflow_version_idx', ['tenant_id', 'app_id', 'version'], unique=False) with op.batch_alter_table('app_model_configs', schema=None) as batch_op: batch_op.add_column(sa.Column('chatbot_app_engine', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False)) diff --git a/api/models/model.py b/api/models/model.py index 2b44957b06..58e29cd21c 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1,5 +1,7 @@ import json import uuid +from enum import Enum +from typing import Optional from flask import current_app, request from flask_login import UserMixin @@ -25,6 +27,25 @@ class DifySetup(db.Model): setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) +class AppMode(Enum): + WORKFLOW = 'workflow' + CHAT = 'chat' + AGENT = 'agent' + + @classmethod + def value_of(cls, value: str) -> 'AppMode': + """ + Get value of given mode. 
+ + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid mode value {value}') + + class App(db.Model): __tablename__ = 'apps' __table_args__ = ( @@ -56,7 +77,7 @@ class App(db.Model): return site @property - def app_model_config(self): + def app_model_config(self) -> Optional['AppModelConfig']: app_model_config = db.session.query(AppModelConfig).filter( AppModelConfig.id == self.app_model_config_id).first() return app_model_config @@ -130,6 +151,12 @@ class App(db.Model): return deleted_tools + +class ChatbotAppEngine(Enum): + NORMAL = 'normal' + WORKFLOW = 'workflow' + + class AppModelConfig(db.Model): __tablename__ = 'app_model_configs' __table_args__ = ( diff --git a/api/models/workflow.py b/api/models/workflow.py index 59b8eeb6cd..ed26e98896 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,6 +1,43 @@ +from enum import Enum +from typing import Union + from sqlalchemy.dialects.postgresql import UUID from extensions.ext_database import db +from models.account import Account +from models.model import AppMode + + +class WorkflowType(Enum): + """ + Workflow Type Enum + """ + WORKFLOW = 'workflow' + CHAT = 'chat' + + @classmethod + def value_of(cls, value: str) -> 'WorkflowType': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid workflow type value {value}') + + @classmethod + def from_app_mode(cls, app_mode: Union[str, AppMode]) -> 'WorkflowType': + """ + Get workflow type from app mode. + + :param app_mode: app mode + :return: workflow type + """ + app_mode = app_mode if isinstance(app_mode, AppMode) else AppMode.value_of(app_mode) + return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT class Workflow(db.Model): @@ -39,7 +76,7 @@ class Workflow(db.Model): __tablename__ = 'workflows' __table_args__ = ( db.PrimaryKeyConstraint('id', name='workflow_pkey'), - db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'type', 'version'), + db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'version'), ) id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) @@ -53,6 +90,14 @@ class Workflow(db.Model): updated_by = db.Column(UUID) updated_at = db.Column(db.DateTime) + @property + def created_by_account(self): + return Account.query.get(self.created_by) + + @property + def updated_by_account(self): + return Account.query.get(self.updated_by) + class WorkflowRun(db.Model): """ @@ -116,6 +161,14 @@ class WorkflowRun(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) finished_at = db.Column(db.DateTime) + @property + def created_by_account(self): + return Account.query.get(self.created_by) + + @property + def updated_by_account(self): + return Account.query.get(self.updated_by) + class WorkflowNodeExecution(db.Model): """ diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index 3cf58d8e09..1e893e0eca 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -1,7 +1,6 @@ import copy -from core.entities.application_entities import AppMode from core.prompt.advanced_prompt_templates import ( BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, @@ -14,6 +13,7 @@ from core.prompt.advanced_prompt_templates import ( 
COMPLETION_APP_COMPLETION_PROMPT_CONFIG, CONTEXT, ) +from models.model import AppMode class AdvancedPromptTemplateService: diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index ccfb101405..3ac11c645c 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -9,6 +9,7 @@ from core.model_runtime.model_providers import model_provider_factory from core.moderation.factory import ModerationFactory from core.provider_manager import ProviderManager from models.account import Account +from models.model import AppMode from services.dataset_service import DatasetService SUPPORT_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] @@ -315,9 +316,6 @@ class AppModelConfigService: if "tool_parameters" not in tool: raise ValueError("tool_parameters is required in agent_mode.tools") - # dataset_query_variable - cls.is_dataset_query_variable_valid(config, app_mode) - # advanced prompt validation cls.is_advanced_prompt_valid(config, app_mode) @@ -443,21 +441,6 @@ class AppModelConfigService: config=config ) - @classmethod - def is_dataset_query_variable_valid(cls, config: dict, mode: str) -> None: - # Only check when mode is completion - if mode != 'completion': - return - - agent_mode = config.get("agent_mode", {}) - tools = agent_mode.get("tools", []) - dataset_exists = "dataset" in str(tools) - - dataset_query_variable = config.get("dataset_query_variable") - - if dataset_exists and not dataset_query_variable: - raise ValueError("Dataset query variable is required when dataset is exist") - @classmethod def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None: # prompt_type diff --git a/api/services/completion_service.py b/api/services/completion_service.py index cbfbe9ef41..5599c60113 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -8,12 +8,10 @@ from core.application_manager import ApplicationManager from core.entities.application_entities import InvokeFrom from core.file.message_file_parser import MessageFileParser from extensions.ext_database import db -from models.model import Account, App, AppModelConfig, Conversation, EndUser, Message +from models.model import Account, App, AppModelConfig, Conversation, EndUser from services.app_model_config_service import AppModelConfigService -from services.errors.app import MoreLikeThisDisabledError from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError -from services.errors.message import MessageNotExistsError class CompletionService: @@ -157,62 +155,6 @@ class CompletionService: } ) - @classmethod - def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], - message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \ - -> Union[dict, Generator]: - if not user: - raise ValueError('user cannot be None') - - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id, - Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), - Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), - Message.from_account_id == (user.id if isinstance(user, Account) else None), - ).first() - - if not message: - raise MessageNotExistsError() - - current_app_model_config = app_model.app_model_config - more_like_this = current_app_model_config.more_like_this_dict - - if not 
current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False: - raise MoreLikeThisDisabledError() - - app_model_config = message.app_model_config - model_dict = app_model_config.model_dict - completion_params = model_dict.get('completion_params') - completion_params['temperature'] = 0.9 - model_dict['completion_params'] = completion_params - app_model_config.model = json.dumps(model_dict) - - # parse files - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) - file_objs = message_file_parser.transform_message_files( - message.files, app_model_config - ) - - application_manager = ApplicationManager() - return application_manager.generate( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - app_model_config_id=app_model_config.id, - app_model_config_dict=app_model_config.to_dict(), - app_model_config_override=True, - user=user, - invoke_from=invoke_from, - inputs=message.inputs, - query=message.query, - files=file_objs, - conversation=None, - stream=streaming, - extras={ - "auto_generate_conversation_name": False - } - ) - @classmethod def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig): if user_inputs is None: diff --git a/api/services/errors/__init__.py b/api/services/errors/__init__.py index 5804f599fe..a44c190cbc 100644 --- a/api/services/errors/__init__.py +++ b/api/services/errors/__init__.py @@ -1,7 +1,7 @@ # -*- coding:utf-8 -*- __all__ = [ 'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset', - 'app', 'completion', 'audio', 'file' + 'completion', 'audio', 'file' ] from . import * diff --git a/api/services/errors/app.py b/api/services/errors/app.py deleted file mode 100644 index 7c4ca99c2a..0000000000 --- a/api/services/errors/app.py +++ /dev/null @@ -1,2 +0,0 @@ -class MoreLikeThisDisabledError(Exception): - pass diff --git a/api/services/workflow/__init__.py b/api/services/workflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/workflow/defaults.py b/api/services/workflow/defaults.py new file mode 100644 index 0000000000..67804fa4eb --- /dev/null +++ b/api/services/workflow/defaults.py @@ -0,0 +1,72 @@ +# default block config +default_block_configs = [ + { + "type": "llm", + "config": { + "prompt_templates": { + "chat_model": { + "prompts": [ + { + "role": "system", + "text": "You are a helpful AI assistant." 
+                    }
+                ]
+            },
+            "completion_model": {
+                "conversation_histories_role": {
+                    "user_prefix": "Human",
+                    "assistant_prefix": "Assistant"
+                },
+                "prompt": {
+                    "text": "Here is the chat histories between human and assistant, inside "
+                            "<histories></histories> XML tags.\n\n<histories>\n{{"
+                            "#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant:"
+                },
+                "stop": ["Human:"]
+            }
+        }
+    },
+    {
+        "type": "code",
+        "config": {
+            "variables": [
+                {
+                    "variable": "arg1",
+                    "value_selector": []
+                },
+                {
+                    "variable": "arg2",
+                    "value_selector": []
+                }
+            ],
+            "code_language": "python3",
+            "code": "def main(\n    arg1: int,\n    arg2: int,\n) -> int:\n    return {\n        \"result\": arg1 "
+                    "+ arg2\n    }",
+            "outputs": [
+                {
+                    "variable": "result",
+                    "variable_type": "number"
+                }
+            ]
+        }
+    },
+    {
+        "type": "template-transform",
+        "config": {
+            "variables": [
+                {
+                    "variable": "arg1",
+                    "value_selector": []
+                }
+            ],
+            "template": "{{ arg1 }}"
+        }
+    },
+    {
+        "type": "question-classifier",
+        "config": {
+            "instructions": ""  # TODO
+        }
+    }
+]
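For orientation, a sketch of how the default block configs above surface through the service layer; the lookup-by-type indexing is just for the example:

```python
from services.workflow_service import WorkflowService

configs = WorkflowService().get_default_block_configs()

# Each entry pairs a node type with its default config; index by type.
by_type = {c['type']: c['config'] for c in configs}
assert 'llm' in by_type and 'code' in by_type
print(by_type['template-transform']['template'])  # {{ arg1 }}
```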
diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py
new file mode 100644
index 0000000000..c2fad83aaf
--- /dev/null
+++ b/api/services/workflow/workflow_converter.py
@@ -0,0 +1,259 @@
+import json
+from typing import Optional
+
+from core.application_manager import ApplicationManager
+from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, FileUploadEntity, \
+    ExternalDataVariableEntity, DatasetEntity, VariableEntity
+from core.model_runtime.utils import helper
+from core.workflow.entities.NodeEntities import NodeType
+from core.workflow.nodes.end.entities import EndNodeOutputType
+from extensions.ext_database import db
+from models.account import Account
+from models.model import App, AppMode, ChatbotAppEngine
+from models.workflow import Workflow, WorkflowType
+
+
+class WorkflowConverter:
+    """
+    App Convert to Workflow Mode
+    """
+
+    def convert_to_workflow(self, app_model: App, account: Account) -> Workflow:
+        """
+        Convert to workflow mode
+
+        - basic mode of chatbot app
+
+        - advanced mode of assistant app (for migration)
+
+        - completion app (for migration)
+
+        :param app_model: App instance
+        :param account: Account instance
+        :return: workflow instance
+        """
+        # get original app config
+        app_model_config = app_model.app_model_config
+
+        # convert app model config
+        application_manager = ApplicationManager()
+        application_manager.convert_from_app_model_config_dict(
+            tenant_id=app_model.tenant_id,
+            app_model_config_dict=app_model_config.to_dict()
+        )
+
+        # init workflow graph
+        graph = {
+            "nodes": [],
+            "edges": []
+        }
+
+        # Convert list:
+        # - variables -> start
+        # - model_config -> llm
+        # - prompt_template -> llm
+        # - file_upload -> llm
+        # - external_data_variables -> http-request
+        # - dataset -> knowledge-retrieval
+        # - show_retrieve_source -> knowledge-retrieval
+
+        # convert to start node
+        start_node = self._convert_to_start_node(
+            variables=app_model_config.variables
+        )
+
+        graph['nodes'].append(start_node)
+
+        # convert to http request node
+        if app_model_config.external_data_variables:
+            http_request_node = self._convert_to_http_request_node(
+                external_data_variables=app_model_config.external_data_variables
+            )
+
+            graph = self._append_node(graph, http_request_node)
+
+        # convert to knowledge retrieval node
+        if app_model_config.dataset:
+            knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node(
+                new_app_mode=self._get_new_app_mode(app_model),
+                dataset=app_model_config.dataset,
+                show_retrieve_source=app_model_config.show_retrieve_source
+            )
+
+            graph = self._append_node(graph, knowledge_retrieval_node)
+
+        # convert to llm node
+        llm_node = self._convert_to_llm_node(
+            model_config=app_model_config.model_config,
+            prompt_template=app_model_config.prompt_template,
+            file_upload=app_model_config.file_upload
+        )
+
+        graph = self._append_node(graph, llm_node)
+
+        # convert to end node by app mode
+        end_node = self._convert_to_end_node(app_model=app_model)
+
+        graph = self._append_node(graph, end_node)
+
+        # get new app mode
+        app_mode = self._get_new_app_mode(app_model)
+
+        # create workflow record
+        workflow = Workflow(
+            tenant_id=app_model.tenant_id,
+            app_id=app_model.id,
+            type=WorkflowType.from_app_mode(app_mode).value,
+            version='draft',
+            graph=json.dumps(graph),
+            created_by=account.id
+        )
+
+        db.session.add(workflow)
+        db.session.flush()
+
+        # create new app model config record
+        new_app_model_config = app_model_config.copy()
+        new_app_model_config.external_data_tools = ''
+        new_app_model_config.model = ''
+        new_app_model_config.user_input_form = ''
+        new_app_model_config.dataset_query_variable = None
+        new_app_model_config.pre_prompt = None
+        new_app_model_config.agent_mode = ''
+        new_app_model_config.prompt_type = 'simple'
+        new_app_model_config.chat_prompt_config = ''
+        new_app_model_config.completion_prompt_config = ''
+        new_app_model_config.dataset_configs = ''
+        new_app_model_config.chatbot_app_engine = ChatbotAppEngine.WORKFLOW.value \
+            if app_mode == AppMode.CHAT else ChatbotAppEngine.NORMAL.value
+        new_app_model_config.workflow_id = workflow.id
+
+        db.session.add(new_app_model_config)
+        db.session.commit()
+
+        return workflow
+
+    def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict:
+        """
+        Convert to Start Node
+        :param variables: list of variables
+        :return:
+        """
+        return {
+            "id": "start",
+            "position": None,
+            "data": {
+                "title": "START",
+                "type": NodeType.START.value,
+                "variables": [helper.dump_model(v) for v in variables]
+            }
+        }
+
+    def _convert_to_http_request_node(self, external_data_variables: list[ExternalDataVariableEntity]) -> dict:
+        """
+        Convert API Based Extension to HTTP Request Node
+        :param external_data_variables: list of external data variables
+        :return:
+        """
+        # TODO: implement
+        pass
+
+    def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, dataset: DatasetEntity,
+                                             show_retrieve_source: bool) -> dict:
+        """
+        Convert datasets to Knowledge Retrieval Node
+        :param new_app_mode: new app mode
+        :param dataset: dataset
+        :param show_retrieve_source: show retrieve source
+        :return:
+        """
+        # TODO: implement
+        if new_app_mode == AppMode.CHAT:
+            query_variable_selector = ["start", "sys.query"]
+        else:
+            pass
+
+        return {
+            "id": "knowledge-retrieval",
+            "position": None,
+            "data": {
+                "title": "KNOWLEDGE RETRIEVAL",
+                "type": NodeType.KNOWLEDGE_RETRIEVAL.value,
+            }
+        }
+
+    def _convert_to_llm_node(self, model_config: ModelConfigEntity,
+                             prompt_template: PromptTemplateEntity,
+                             file_upload: Optional[FileUploadEntity] = None) -> dict:
+        """
+        Convert to LLM Node
+        :param model_config: model config
+        :param prompt_template: prompt template
+        :param file_upload: file upload config (optional)
+        """
+        # TODO: implement
+        pass
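`_convert_to_llm_node` is still a TODO at this point in the series. A hedged sketch of the node dict it might emit, mirroring the layout of the start and knowledge-retrieval nodes above; every key below the `data` level is an assumption, not something this patch defines:

```python
from core.workflow.entities.NodeEntities import NodeType


def convert_to_llm_node_sketch(model_config, prompt_template) -> dict:
    """Illustrative only: shows a plausible shape, not the final field names."""
    return {
        "id": "llm",
        "position": None,
        "data": {
            "title": "LLM",
            "type": NodeType.LLM.value,
            # assumed fields, carried over from the basic-mode model config
            "model": {
                "provider": model_config.provider,
                "name": model_config.model,
                "completion_params": model_config.parameters,
            },
            "prompt_template": prompt_template.simple_prompt_template,
        },
    }
```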
"data": { + "title": "END", + "type": NodeType.END.value, + "outputs": { + "type": EndNodeOutputType.PLAIN_TEXT.value, + "plain_text_selector": ["llm", "text"] + } + } + } + + def _create_edge(self, source: str, target: str) -> dict: + """ + Create Edge + :param source: source node id + :param target: target node id + :return: + """ + return { + "id": f"{source}-{target}", + "source": source, + "target": target + } + + def _append_node(self, graph: dict, node: dict) -> dict: + """ + Append Node to Graph + + :param graph: Graph, include: nodes, edges + :param node: Node to append + :return: + """ + previous_node = graph['nodes'][-1] + graph['nodes'].append(node) + graph['edges'].append(self._create_edge(previous_node['id'], node['id'])) + return graph + + def _get_new_app_mode(self, app_model: App) -> AppMode: + """ + Get new app mode + :param app_model: App instance + :return: AppMode + """ + if app_model.mode == "completion": + return AppMode.WORKFLOW + else: + return AppMode.value_of(app_model.mode) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py new file mode 100644 index 0000000000..6a967e86ff --- /dev/null +++ b/api/services/workflow_service.py @@ -0,0 +1,83 @@ +import json +from datetime import datetime + +from extensions.ext_database import db +from models.account import Account +from models.model import App, ChatbotAppEngine +from models.workflow import Workflow, WorkflowType +from services.workflow.defaults import default_block_configs +from services.workflow.workflow_converter import WorkflowConverter + + +class WorkflowService: + """ + Workflow Service + """ + + def get_draft_workflow(self, app_model: App) -> Workflow: + """ + Get draft workflow + """ + # fetch draft workflow by app_model + workflow = db.session.query(Workflow).filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.version == 'draft' + ).first() + + # return draft workflow + return workflow + + def sync_draft_workflow(self, app_model: App, graph: dict, account: Account) -> Workflow: + """ + Sync draft workflow + """ + # fetch draft workflow by app_model + workflow = self.get_draft_workflow(app_model=app_model) + + # create draft workflow if not found + if not workflow: + workflow = Workflow( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type=WorkflowType.from_app_mode(app_model.mode).value, + version='draft', + graph=json.dumps(graph), + created_by=account.id + ) + db.session.add(workflow) + # update draft workflow if found + else: + workflow.graph = json.dumps(graph) + workflow.updated_by = account.id + workflow.updated_at = datetime.utcnow() + + # commit db session changes + db.session.commit() + + # return draft workflow + return workflow + + def get_default_block_configs(self) -> dict: + """ + Get default block configs + """ + # return default block config + return default_block_configs + + def chatbot_convert_to_workflow(self, app_model: App) -> Workflow: + """ + basic mode of chatbot app to workflow + + :param app_model: App instance + :return: + """ + # check if chatbot app is in basic mode + if app_model.app_model_config.chatbot_app_engine != ChatbotAppEngine.NORMAL: + raise ValueError('Chatbot app already in workflow mode') + + # convert to workflow mode + workflow_converter = WorkflowConverter() + workflow = workflow_converter.convert_to_workflow(app_model=app_model) + + return workflow From c028e5f889b835f5bf8ed84e4f2ccad7879b3d0c Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 22 Feb 2024 03:20:28 +0800 
Subject: [PATCH 046/200] add app convert codes --- api/controllers/console/app/conversation.py | 2 +- api/controllers/console/app/message.py | 2 +- api/controllers/console/app/workflow.py | 6 +- api/controllers/console/app/wraps.py | 2 +- api/core/app_runner/app_runner.py | 17 +- api/core/app_runner/basic_app_runner.py | 2 +- api/core/application_manager.py | 6 +- api/core/entities/application_entities.py | 1 - api/core/prompt/advanced_prompt_transform.py | 198 +++++++ .../generate_prompts/baichuan_chat.json | 6 +- .../generate_prompts/baichuan_completion.json | 4 +- .../prompt/generate_prompts/common_chat.json | 6 +- .../generate_prompts/common_completion.json | 4 +- api/core/prompt/prompt_builder.py | 10 - api/core/prompt/prompt_template.py | 3 +- api/core/prompt/prompt_transform.py | 552 +----------------- api/core/prompt/simple_prompt_transform.py | 298 ++++++++++ api/fields/annotation_fields.py | 1 - api/fields/workflow_fields.py | 1 - api/services/workflow/workflow_converter.py | 168 +++++- 20 files changed, 696 insertions(+), 593 deletions(-) create mode 100644 api/core/prompt/advanced_prompt_transform.py delete mode 100644 api/core/prompt/prompt_builder.py create mode 100644 api/core/prompt/simple_prompt_transform.py diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 5d312149f7..daf9641121 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -21,7 +21,7 @@ from fields.conversation_fields import ( ) from libs.helper import datetime_string from libs.login import login_required -from models.model import Conversation, Message, MessageAnnotation, AppMode +from models.model import AppMode, Conversation, Message, MessageAnnotation class CompletionConversationApi(Resource): diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 9a177116ea..c384e878aa 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -26,7 +26,7 @@ from fields.conversation_fields import annotation_fields, message_detail_fields from libs.helper import uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.login import login_required -from models.model import Conversation, Message, MessageAnnotation, MessageFeedback, AppMode +from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback from services.annotation_service import AppAnnotationService from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 2794735bbb..1bb0ea34c1 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse, marshal_with +from flask_restful import Resource, marshal_with, reqparse from controllers.console import api from controllers.console.app.error import DraftWorkflowNotExist @@ -6,8 +6,8 @@ from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from fields.workflow_fields import workflow_fields -from libs.login import login_required, current_user -from models.model import App, ChatbotAppEngine, AppMode +from libs.login import current_user, login_required +from models.model import App, 
AppMode, ChatbotAppEngine from services.workflow_service import WorkflowService diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index fe35e72304..1c2c4cf5c7 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -5,7 +5,7 @@ from typing import Optional, Union from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from libs.login import current_user -from models.model import App, ChatbotAppEngine, AppMode +from models.model import App, AppMode, ChatbotAppEngine def get_app_model(view: Optional[Callable] = None, *, diff --git a/api/core/app_runner/app_runner.py b/api/core/app_runner/app_runner.py index f9678b372f..c6f6268a7a 100644 --- a/api/core/app_runner/app_runner.py +++ b/api/core/app_runner/app_runner.py @@ -22,7 +22,7 @@ from core.model_runtime.entities.message_entities import AssistantPromptMessage, from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.prompt.prompt_transform import PromptTransform +from core.prompt.simple_prompt_transform import SimplePromptTransform from models.model import App, Message, MessageAnnotation @@ -140,12 +140,11 @@ class AppRunner: :param memory: memory :return: """ - prompt_transform = PromptTransform() + prompt_transform = SimplePromptTransform() # get prompt without memory and context if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: prompt_messages, stop = prompt_transform.get_prompt( - app_mode=app_record.mode, prompt_template_entity=prompt_template_entity, inputs=inputs, query=query if query else '', @@ -155,17 +154,7 @@ class AppRunner: model_config=model_config ) else: - prompt_messages = prompt_transform.get_advanced_prompt( - app_mode=app_record.mode, - prompt_template_entity=prompt_template_entity, - inputs=inputs, - query=query, - files=files, - context=context, - memory=memory, - model_config=model_config - ) - stop = model_config.stop + raise NotImplementedError("Advanced prompt is not supported yet.") return prompt_messages, stop diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app_runner/basic_app_runner.py index 26e9cc84aa..0e0fe6e3bf 100644 --- a/api/core/app_runner/basic_app_runner.py +++ b/api/core/app_runner/basic_app_runner.py @@ -15,7 +15,7 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationException from extensions.ext_database import db -from models.model import App, Conversation, Message, AppMode +from models.model import App, AppMode, Conversation, Message logger = logging.getLogger(__name__) diff --git a/api/core/application_manager.py b/api/core/application_manager.py index 2fde422d47..cf463be1df 100644 --- a/api/core/application_manager.py +++ b/api/core/application_manager.py @@ -28,7 +28,8 @@ from core.entities.application_entities import ( ModelConfigEntity, PromptTemplateEntity, SensitiveWordAvoidanceEntity, - TextToSpeechEntity, VariableEntity, + TextToSpeechEntity, + VariableEntity, ) from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError @@ -541,8 +542,7 @@ class ApplicationManager: query_variable=query_variable, 
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( dataset_configs['retrieval_model'] - ), - single_strategy=datasets.get('strategy', 'router') + ) ) ) else: diff --git a/api/core/entities/application_entities.py b/api/core/entities/application_entities.py index 092591a73f..f8f293d96a 100644 --- a/api/core/entities/application_entities.py +++ b/api/core/entities/application_entities.py @@ -156,7 +156,6 @@ class DatasetRetrieveConfigEntity(BaseModel): query_variable: Optional[str] = None # Only when app mode is completion retrieve_strategy: RetrieveStrategy - single_strategy: Optional[str] = None # for temp top_k: Optional[int] = None score_threshold: Optional[float] = None reranking_model: Optional[dict] = None diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py new file mode 100644 index 0000000000..9ca3ef0375 --- /dev/null +++ b/api/core/prompt/advanced_prompt_transform.py @@ -0,0 +1,198 @@ +from typing import Optional + +from core.entities.application_entities import PromptTemplateEntity, ModelConfigEntity, \ + AdvancedCompletionPromptTemplateEntity +from core.file.file_obj import FileObj +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, UserPromptMessage, \ + SystemPromptMessage, AssistantPromptMessage, TextPromptMessageContent +from core.prompt.prompt_template import PromptTemplateParser +from core.prompt.prompt_transform import PromptTransform +from core.prompt.simple_prompt_transform import ModelMode + + +class AdvancePromptTransform(PromptTransform): + """ + Advanced Prompt Transform for Workflow LLM Node. + """ + + def get_prompt(self, prompt_template_entity: PromptTemplateEntity, + inputs: dict, + query: str, + files: list[FileObj], + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigEntity) -> list[PromptMessage]: + prompt_messages = [] + + model_mode = ModelMode.value_of(model_config.mode) + if model_mode == ModelMode.COMPLETION: + prompt_messages = self._get_completion_model_prompt_messages( + prompt_template_entity=prompt_template_entity, + inputs=inputs, + files=files, + context=context, + memory=memory, + model_config=model_config + ) + elif model_mode == ModelMode.CHAT: + prompt_messages = self._get_chat_model_prompt_messages( + prompt_template_entity=prompt_template_entity, + inputs=inputs, + query=query, + files=files, + context=context, + memory=memory, + model_config=model_config + ) + + return prompt_messages + + def _get_completion_model_prompt_messages(self, + prompt_template_entity: PromptTemplateEntity, + inputs: dict, + files: list[FileObj], + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigEntity) -> list[PromptMessage]: + """ + Get completion model prompt messages. 
+ """ + raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt + + prompt_messages = [] + + prompt_template = PromptTemplateParser(template=raw_prompt) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + + self._set_context_variable(context, prompt_template, prompt_inputs) + + role_prefix = prompt_template_entity.advanced_completion_prompt_template.role_prefix + self._set_histories_variable( + memory=memory, + raw_prompt=raw_prompt, + role_prefix=role_prefix, + prompt_template=prompt_template, + prompt_inputs=prompt_inputs, + model_config=model_config + ) + + prompt = prompt_template.format( + prompt_inputs + ) + + if files: + prompt_message_contents = [TextPromptMessageContent(data=prompt)] + for file in files: + prompt_message_contents.append(file.prompt_message_content) + + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + else: + prompt_messages.append(UserPromptMessage(content=prompt)) + + return prompt_messages + + def _get_chat_model_prompt_messages(self, + prompt_template_entity: PromptTemplateEntity, + inputs: dict, + query: str, + files: list[FileObj], + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigEntity) -> list[PromptMessage]: + """ + Get chat model prompt messages. + """ + raw_prompt_list = prompt_template_entity.advanced_chat_prompt_template.messages + + prompt_messages = [] + + for prompt_item in raw_prompt_list: + raw_prompt = prompt_item.text + + prompt_template = PromptTemplateParser(template=raw_prompt) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + + self._set_context_variable(context, prompt_template, prompt_inputs) + + prompt = prompt_template.format( + prompt_inputs + ) + + if prompt_item.role == PromptMessageRole.USER: + prompt_messages.append(UserPromptMessage(content=prompt)) + elif prompt_item.role == PromptMessageRole.SYSTEM and prompt: + prompt_messages.append(SystemPromptMessage(content=prompt)) + elif prompt_item.role == PromptMessageRole.ASSISTANT: + prompt_messages.append(AssistantPromptMessage(content=prompt)) + + if memory: + self._append_chat_histories(memory, prompt_messages, model_config) + + if files: + prompt_message_contents = [TextPromptMessageContent(data=query)] + for file in files: + prompt_message_contents.append(file.prompt_message_content) + + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + else: + prompt_messages.append(UserPromptMessage(content=query)) + elif files: + # get last message + last_message = prompt_messages[-1] if prompt_messages else None + if last_message and last_message.role == PromptMessageRole.USER: + # get last user message content and add files + prompt_message_contents = [TextPromptMessageContent(data=last_message.content)] + for file in files: + prompt_message_contents.append(file.prompt_message_content) + + last_message.content = prompt_message_contents + else: + prompt_message_contents = [TextPromptMessageContent(data=query)] + for file in files: + prompt_message_contents.append(file.prompt_message_content) + + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + + return prompt_messages + + def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None: + if '#context#' in prompt_template.variable_keys: + if context: + prompt_inputs['#context#'] = context + else: + prompt_inputs['#context#'] = '' + + def _set_query_variable(self, query: str, 
prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None: + if '#query#' in prompt_template.variable_keys: + if query: + prompt_inputs['#query#'] = query + else: + prompt_inputs['#query#'] = '' + + def _set_histories_variable(self, memory: TokenBufferMemory, + raw_prompt: str, + role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity, + prompt_template: PromptTemplateParser, + prompt_inputs: dict, + model_config: ModelConfigEntity) -> None: + if '#histories#' in prompt_template.variable_keys: + if memory: + inputs = {'#histories#': '', **prompt_inputs} + prompt_template = PromptTemplateParser(raw_prompt) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + tmp_human_message = UserPromptMessage( + content=prompt_template.format(prompt_inputs) + ) + + rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) + + histories = self._get_history_messages_from_memory( + memory=memory, + max_token_limit=rest_tokens, + human_prefix=role_prefix.user, + ai_prefix=role_prefix.assistant + ) + prompt_inputs['#histories#'] = histories + else: + prompt_inputs['#histories#'] = '' diff --git a/api/core/prompt/generate_prompts/baichuan_chat.json b/api/core/prompt/generate_prompts/baichuan_chat.json index 5bf83cd9c7..03b6a53cff 100644 --- a/api/core/prompt/generate_prompts/baichuan_chat.json +++ b/api/core/prompt/generate_prompts/baichuan_chat.json @@ -1,13 +1,13 @@ { "human_prefix": "用户", "assistant_prefix": "助手", - "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{context}}\n```\n\n", - "histories_prompt": "用户和助手的历史对话内容如下:\n```\n{{histories}}\n```\n\n", + "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n", + "histories_prompt": "用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n", "system_prompt_orders": [ "context_prompt", "pre_prompt", "histories_prompt" ], - "query_prompt": "\n\n用户:{{query}}", + "query_prompt": "\n\n用户:{{#query#}}", "stops": ["用户:"] } \ No newline at end of file diff --git a/api/core/prompt/generate_prompts/baichuan_completion.json b/api/core/prompt/generate_prompts/baichuan_completion.json index a3a2054e83..ae8c0dac53 100644 --- a/api/core/prompt/generate_prompts/baichuan_completion.json +++ b/api/core/prompt/generate_prompts/baichuan_completion.json @@ -1,9 +1,9 @@ { - "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{context}}\n```\n", + "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n", "system_prompt_orders": [ "context_prompt", "pre_prompt" ], - "query_prompt": "{{query}}", + "query_prompt": "{{#query#}}", "stops": null } \ No newline at end of file diff --git a/api/core/prompt/generate_prompts/common_chat.json b/api/core/prompt/generate_prompts/common_chat.json index 709a8d8866..d398a512e6 100644 --- a/api/core/prompt/generate_prompts/common_chat.json +++ b/api/core/prompt/generate_prompts/common_chat.json @@ -1,13 +1,13 @@ { "human_prefix": "Human", "assistant_prefix": "Assistant", - "context_prompt": "Use the following context as your learned knowledge, inside XML tags.\n\n\n{{context}}\n\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n", - "histories_prompt": "Here is the chat histories 
between human and assistant, inside XML tags.\n\n\n{{histories}}\n\n\n", + "context_prompt": "Use the following context as your learned knowledge, inside XML tags.\n\n\n{{#context#}}\n\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n", + "histories_prompt": "Here is the chat histories between human and assistant, inside XML tags.\n\n\n{{#histories#}}\n\n\n", "system_prompt_orders": [ "context_prompt", "pre_prompt", "histories_prompt" ], - "query_prompt": "\n\nHuman: {{query}}\n\nAssistant: ", + "query_prompt": "\n\nHuman: {{#query#}}\n\nAssistant: ", "stops": ["\nHuman:", ""] } diff --git a/api/core/prompt/generate_prompts/common_completion.json b/api/core/prompt/generate_prompts/common_completion.json index 9e7e8d68ef..c148772010 100644 --- a/api/core/prompt/generate_prompts/common_completion.json +++ b/api/core/prompt/generate_prompts/common_completion.json @@ -1,9 +1,9 @@ { - "context_prompt": "Use the following context as your learned knowledge, inside XML tags.\n\n\n{{context}}\n\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n", + "context_prompt": "Use the following context as your learned knowledge, inside XML tags.\n\n\n{{#context#}}\n\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n", "system_prompt_orders": [ "context_prompt", "pre_prompt" ], - "query_prompt": "{{query}}", + "query_prompt": "{{#query#}}", "stops": null } \ No newline at end of file diff --git a/api/core/prompt/prompt_builder.py b/api/core/prompt/prompt_builder.py deleted file mode 100644 index 7727b0f92e..0000000000 --- a/api/core/prompt/prompt_builder.py +++ /dev/null @@ -1,10 +0,0 @@ -from core.prompt.prompt_template import PromptTemplateParser - - -class PromptBuilder: - @classmethod - def parse_prompt(cls, prompt: str, inputs: dict) -> str: - prompt_template = PromptTemplateParser(prompt) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - prompt = prompt_template.format(prompt_inputs) - return prompt diff --git a/api/core/prompt/prompt_template.py b/api/core/prompt/prompt_template.py index 32c5a791de..454f92e3b7 100644 --- a/api/core/prompt/prompt_template.py +++ b/api/core/prompt/prompt_template.py @@ -32,7 +32,8 @@ class PromptTemplateParser: return PromptTemplateParser.remove_template_variables(value) return value - return re.sub(REGEX, replacer, self.template) + prompt = re.sub(REGEX, replacer, self.template) + return re.sub(r'<\|.*?\|>', '', prompt) @classmethod def remove_template_variables(cls, text: str): diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index abbfa96249..c0f70ae0bb 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -1,393 +1,13 @@ -import enum -import json -import os -import re from typing import Optional, cast -from core.entities.application_entities import ( - 
AdvancedCompletionPromptTemplateEntity, - ModelConfigEntity, - PromptTemplateEntity, -) -from core.file.file_obj import FileObj +from core.entities.application_entities import ModelConfigEntity from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageRole, - SystemPromptMessage, - TextPromptMessageContent, - UserPromptMessage, -) +from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.prompt.prompt_builder import PromptBuilder -from core.prompt.prompt_template import PromptTemplateParser -from models.model import AppMode - - -class ModelMode(enum.Enum): - COMPLETION = 'completion' - CHAT = 'chat' - - @classmethod - def value_of(cls, value: str) -> 'ModelMode': - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f'invalid mode value {value}') class PromptTransform: - def get_prompt(self, - app_mode: str, - prompt_template_entity: PromptTemplateEntity, - inputs: dict, - query: str, - files: list[FileObj], - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> \ - tuple[list[PromptMessage], Optional[list[str]]]: - app_mode = AppMode.value_of(app_mode) - model_mode = ModelMode.value_of(model_config.mode) - - prompt_rules = self._read_prompt_rules_from_file(self._prompt_file_name( - app_mode=app_mode, - provider=model_config.provider, - model=model_config.model - )) - - if app_mode == AppMode.CHAT and model_mode == ModelMode.CHAT: - stops = None - - prompt_messages = self._get_simple_chat_app_chat_model_prompt_messages( - prompt_rules=prompt_rules, - pre_prompt=prompt_template_entity.simple_prompt_template, - inputs=inputs, - query=query, - files=files, - context=context, - memory=memory, - model_config=model_config - ) - else: - stops = prompt_rules.get('stops') - if stops is not None and len(stops) == 0: - stops = None - - prompt_messages = self._get_simple_others_prompt_messages( - prompt_rules=prompt_rules, - pre_prompt=prompt_template_entity.simple_prompt_template, - inputs=inputs, - query=query, - files=files, - context=context, - memory=memory, - model_config=model_config - ) - return prompt_messages, stops - - def get_advanced_prompt(self, app_mode: str, - prompt_template_entity: PromptTemplateEntity, - inputs: dict, - query: str, - files: list[FileObj], - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> list[PromptMessage]: - app_mode = AppMode.value_of(app_mode) - model_mode = ModelMode.value_of(model_config.mode) - - prompt_messages = [] - - if app_mode == AppMode.CHAT: - if model_mode == ModelMode.COMPLETION: - prompt_messages = self._get_chat_app_completion_model_prompt_messages( - prompt_template_entity=prompt_template_entity, - inputs=inputs, - query=query, - files=files, - context=context, - memory=memory, - model_config=model_config - ) - elif model_mode == ModelMode.CHAT: - prompt_messages = self._get_chat_app_chat_model_prompt_messages( - prompt_template_entity=prompt_template_entity, - inputs=inputs, - query=query, - files=files, - context=context, - memory=memory, - model_config=model_config - ) - elif app_mode == AppMode.COMPLETION: - if model_mode == 
ModelMode.CHAT: - prompt_messages = self._get_completion_app_chat_model_prompt_messages( - prompt_template_entity=prompt_template_entity, - inputs=inputs, - files=files, - context=context, - ) - elif model_mode == ModelMode.COMPLETION: - prompt_messages = self._get_completion_app_completion_model_prompt_messages( - prompt_template_entity=prompt_template_entity, - inputs=inputs, - context=context, - ) - - return prompt_messages - - def _get_history_messages_from_memory(self, memory: TokenBufferMemory, - max_token_limit: int, - human_prefix: Optional[str] = None, - ai_prefix: Optional[str] = None) -> str: - """Get memory messages.""" - kwargs = { - "max_token_limit": max_token_limit - } - - if human_prefix: - kwargs['human_prefix'] = human_prefix - - if ai_prefix: - kwargs['ai_prefix'] = ai_prefix - - return memory.get_history_prompt_text( - **kwargs - ) - - def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory, - max_token_limit: int) -> list[PromptMessage]: - """Get memory messages.""" - return memory.get_history_prompt_messages( - max_token_limit=max_token_limit - ) - - def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str: - # baichuan - if provider == 'baichuan': - return self._prompt_file_name_for_baichuan(app_mode) - - baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"] - if provider in baichuan_supported_providers and 'baichuan' in model.lower(): - return self._prompt_file_name_for_baichuan(app_mode) - - # common - if app_mode == AppMode.COMPLETION: - return 'common_completion' - else: - return 'common_chat' - - def _prompt_file_name_for_baichuan(self, app_mode: AppMode) -> str: - if app_mode == AppMode.COMPLETION: - return 'baichuan_completion' - else: - return 'baichuan_chat' - - def _read_prompt_rules_from_file(self, prompt_name: str) -> dict: - # Get the absolute path of the subdirectory - prompt_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - 'generate_prompts') - - json_file_path = os.path.join(prompt_path, f'{prompt_name}.json') - # Open the JSON file and read its content - with open(json_file_path, encoding='utf-8') as json_file: - return json.load(json_file) - - def _get_simple_chat_app_chat_model_prompt_messages(self, prompt_rules: dict, - pre_prompt: str, - inputs: dict, - query: str, - context: Optional[str], - files: list[FileObj], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> list[PromptMessage]: - prompt_messages = [] - - context_prompt_content = '' - if context and 'context_prompt' in prompt_rules: - prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt']) - context_prompt_content = prompt_template.format( - {'context': context} - ) - - pre_prompt_content = '' - if pre_prompt: - prompt_template = PromptTemplateParser(template=pre_prompt) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - pre_prompt_content = prompt_template.format( - prompt_inputs - ) - - prompt = '' - for order in prompt_rules['system_prompt_orders']: - if order == 'context_prompt': - prompt += context_prompt_content - elif order == 'pre_prompt': - prompt += pre_prompt_content - - prompt = re.sub(r'<\|.*?\|>', '', prompt) - - if prompt: - prompt_messages.append(SystemPromptMessage(content=prompt)) - - self._append_chat_histories( - memory=memory, - prompt_messages=prompt_messages, - model_config=model_config - ) - - if files: - prompt_message_contents = [TextPromptMessageContent(data=query)] - for file in files: - 
prompt_message_contents.append(file.prompt_message_content) - - prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) - else: - prompt_messages.append(UserPromptMessage(content=query)) - - return prompt_messages - - def _get_simple_others_prompt_messages(self, prompt_rules: dict, - pre_prompt: str, - inputs: dict, - query: str, - context: Optional[str], - memory: Optional[TokenBufferMemory], - files: list[FileObj], - model_config: ModelConfigEntity) -> list[PromptMessage]: - context_prompt_content = '' - if context and 'context_prompt' in prompt_rules: - prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt']) - context_prompt_content = prompt_template.format( - {'context': context} - ) - - pre_prompt_content = '' - if pre_prompt: - prompt_template = PromptTemplateParser(template=pre_prompt) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - pre_prompt_content = prompt_template.format( - prompt_inputs - ) - - prompt = '' - for order in prompt_rules['system_prompt_orders']: - if order == 'context_prompt': - prompt += context_prompt_content - elif order == 'pre_prompt': - prompt += pre_prompt_content - - query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}' - - if memory and 'histories_prompt' in prompt_rules: - # append chat histories - tmp_human_message = UserPromptMessage( - content=PromptBuilder.parse_prompt( - prompt=prompt + query_prompt, - inputs={ - 'query': query - } - ) - ) - - rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) - - histories = self._get_history_messages_from_memory( - memory=memory, - max_token_limit=rest_tokens, - ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', - human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' - ) - prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt']) - histories_prompt_content = prompt_template.format({'histories': histories}) - - prompt = '' - for order in prompt_rules['system_prompt_orders']: - if order == 'context_prompt': - prompt += context_prompt_content - elif order == 'pre_prompt': - prompt += (pre_prompt_content + '\n') if pre_prompt_content else '' - elif order == 'histories_prompt': - prompt += histories_prompt_content - - prompt_template = PromptTemplateParser(template=query_prompt) - query_prompt_content = prompt_template.format({'query': query}) - - prompt += query_prompt_content - - prompt = re.sub(r'<\|.*?\|>', '', prompt) - - model_mode = ModelMode.value_of(model_config.mode) - - if model_mode == ModelMode.CHAT and files: - prompt_message_contents = [TextPromptMessageContent(data=prompt)] - for file in files: - prompt_message_contents.append(file.prompt_message_content) - - prompt_message = UserPromptMessage(content=prompt_message_contents) - else: - if files: - prompt_message_contents = [TextPromptMessageContent(data=prompt)] - for file in files: - prompt_message_contents.append(file.prompt_message_content) - - prompt_message = UserPromptMessage(content=prompt_message_contents) - else: - prompt_message = UserPromptMessage(content=prompt) - - return [prompt_message] - - def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None: - if '#context#' in prompt_template.variable_keys: - if context: - prompt_inputs['#context#'] = context - else: - prompt_inputs['#context#'] = '' - - def _set_query_variable(self, query: 
str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None: - if '#query#' in prompt_template.variable_keys: - if query: - prompt_inputs['#query#'] = query - else: - prompt_inputs['#query#'] = '' - - def _set_histories_variable(self, memory: TokenBufferMemory, - raw_prompt: str, - role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity, - prompt_template: PromptTemplateParser, - prompt_inputs: dict, - model_config: ModelConfigEntity) -> None: - if '#histories#' in prompt_template.variable_keys: - if memory: - tmp_human_message = UserPromptMessage( - content=PromptBuilder.parse_prompt( - prompt=raw_prompt, - inputs={'#histories#': '', **prompt_inputs} - ) - ) - - rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) - - histories = self._get_history_messages_from_memory( - memory=memory, - max_token_limit=rest_tokens, - human_prefix=role_prefix.user, - ai_prefix=role_prefix.assistant - ) - prompt_inputs['#histories#'] = histories - else: - prompt_inputs['#histories#'] = '' - def _append_chat_histories(self, memory: TokenBufferMemory, prompt_messages: list[PromptMessage], model_config: ModelConfigEntity) -> None: @@ -422,152 +42,28 @@ class PromptTransform: return rest_tokens - def _format_prompt(self, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> str: - prompt = prompt_template.format( - prompt_inputs + def _get_history_messages_from_memory(self, memory: TokenBufferMemory, + max_token_limit: int, + human_prefix: Optional[str] = None, + ai_prefix: Optional[str] = None) -> str: + """Get memory messages.""" + kwargs = { + "max_token_limit": max_token_limit + } + + if human_prefix: + kwargs['human_prefix'] = human_prefix + + if ai_prefix: + kwargs['ai_prefix'] = ai_prefix + + return memory.get_history_prompt_text( + **kwargs ) - prompt = re.sub(r'<\|.*?\|>', '', prompt) - return prompt - - def _get_chat_app_completion_model_prompt_messages(self, - prompt_template_entity: PromptTemplateEntity, - inputs: dict, - query: str, - files: list[FileObj], - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> list[PromptMessage]: - - raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt - role_prefix = prompt_template_entity.advanced_completion_prompt_template.role_prefix - - prompt_messages = [] - - prompt_template = PromptTemplateParser(template=raw_prompt) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - - self._set_context_variable(context, prompt_template, prompt_inputs) - - self._set_query_variable(query, prompt_template, prompt_inputs) - - self._set_histories_variable( - memory=memory, - raw_prompt=raw_prompt, - role_prefix=role_prefix, - prompt_template=prompt_template, - prompt_inputs=prompt_inputs, - model_config=model_config + def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory, + max_token_limit: int) -> list[PromptMessage]: + """Get memory messages.""" + return memory.get_history_prompt_messages( + max_token_limit=max_token_limit ) - - prompt = self._format_prompt(prompt_template, prompt_inputs) - - if files: - prompt_message_contents = [TextPromptMessageContent(data=prompt)] - for file in files: - prompt_message_contents.append(file.prompt_message_content) - - prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) - else: - prompt_messages.append(UserPromptMessage(content=prompt)) - - return prompt_messages - - def _get_chat_app_chat_model_prompt_messages(self, - 
prompt_template_entity: PromptTemplateEntity, - inputs: dict, - query: str, - files: list[FileObj], - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> list[PromptMessage]: - raw_prompt_list = prompt_template_entity.advanced_chat_prompt_template.messages - - prompt_messages = [] - - for prompt_item in raw_prompt_list: - raw_prompt = prompt_item.text - - prompt_template = PromptTemplateParser(template=raw_prompt) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - - self._set_context_variable(context, prompt_template, prompt_inputs) - - prompt = self._format_prompt(prompt_template, prompt_inputs) - - if prompt_item.role == PromptMessageRole.USER: - prompt_messages.append(UserPromptMessage(content=prompt)) - elif prompt_item.role == PromptMessageRole.SYSTEM and prompt: - prompt_messages.append(SystemPromptMessage(content=prompt)) - elif prompt_item.role == PromptMessageRole.ASSISTANT: - prompt_messages.append(AssistantPromptMessage(content=prompt)) - - self._append_chat_histories(memory, prompt_messages, model_config) - - if files: - prompt_message_contents = [TextPromptMessageContent(data=query)] - for file in files: - prompt_message_contents.append(file.prompt_message_content) - - prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) - else: - prompt_messages.append(UserPromptMessage(content=query)) - - return prompt_messages - - def _get_completion_app_completion_model_prompt_messages(self, - prompt_template_entity: PromptTemplateEntity, - inputs: dict, - context: Optional[str]) -> list[PromptMessage]: - raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt - - prompt_messages = [] - - prompt_template = PromptTemplateParser(template=raw_prompt) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - - self._set_context_variable(context, prompt_template, prompt_inputs) - - prompt = self._format_prompt(prompt_template, prompt_inputs) - - prompt_messages.append(UserPromptMessage(content=prompt)) - - return prompt_messages - - def _get_completion_app_chat_model_prompt_messages(self, - prompt_template_entity: PromptTemplateEntity, - inputs: dict, - files: list[FileObj], - context: Optional[str]) -> list[PromptMessage]: - raw_prompt_list = prompt_template_entity.advanced_chat_prompt_template.messages - - prompt_messages = [] - - for prompt_item in raw_prompt_list: - raw_prompt = prompt_item.text - - prompt_template = PromptTemplateParser(template=raw_prompt) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - - self._set_context_variable(context, prompt_template, prompt_inputs) - - prompt = self._format_prompt(prompt_template, prompt_inputs) - - if prompt_item.role == PromptMessageRole.USER: - prompt_messages.append(UserPromptMessage(content=prompt)) - elif prompt_item.role == PromptMessageRole.SYSTEM and prompt: - prompt_messages.append(SystemPromptMessage(content=prompt)) - elif prompt_item.role == PromptMessageRole.ASSISTANT: - prompt_messages.append(AssistantPromptMessage(content=prompt)) - - for prompt_message in prompt_messages[::-1]: - if prompt_message.role == PromptMessageRole.USER: - if files: - prompt_message_contents = [TextPromptMessageContent(data=prompt_message.content)] - for file in files: - prompt_message_contents.append(file.prompt_message_content) - - prompt_message.content = prompt_message_contents - break - - return prompt_messages diff --git 
a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py new file mode 100644 index 0000000000..a898c37c4a --- /dev/null +++ b/api/core/prompt/simple_prompt_transform.py @@ -0,0 +1,298 @@ +import enum +import json +import os +from typing import Optional, Tuple + +from core.entities.application_entities import ( + ModelConfigEntity, + PromptTemplateEntity, +) +from core.file.file_obj import FileObj +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_runtime.entities.message_entities import ( + PromptMessage, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.prompt.prompt_template import PromptTemplateParser +from core.prompt.prompt_transform import PromptTransform +from models.model import AppMode + + +class ModelMode(enum.Enum): + COMPLETION = 'completion' + CHAT = 'chat' + + @classmethod + def value_of(cls, value: str) -> 'ModelMode': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid mode value {value}') + + +prompt_file_contents = {} + + +class SimplePromptTransform(PromptTransform): + """ + Simple Prompt Transform for Chatbot App Basic Mode. + """ + def get_prompt(self, + prompt_template_entity: PromptTemplateEntity, + inputs: dict, + query: str, + files: list[FileObj], + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigEntity) -> \ + tuple[list[PromptMessage], Optional[list[str]]]: + model_mode = ModelMode.value_of(model_config.mode) + if model_mode == ModelMode.CHAT: + prompt_messages, stops = self._get_chat_model_prompt_messages( + pre_prompt=prompt_template_entity.simple_prompt_template, + inputs=inputs, + query=query, + files=files, + context=context, + memory=memory, + model_config=model_config + ) + else: + prompt_messages, stops = self._get_completion_model_prompt_messages( + pre_prompt=prompt_template_entity.simple_prompt_template, + inputs=inputs, + query=query, + files=files, + context=context, + memory=memory, + model_config=model_config + ) + + return prompt_messages, stops + + def get_prompt_str_and_rules(self, app_mode: AppMode, + model_config: ModelConfigEntity, + pre_prompt: str, + inputs: dict, + query: Optional[str] = None, + context: Optional[str] = None, + histories: Optional[str] = None, + ) -> Tuple[str, dict]: + # get prompt template + prompt_template_config = self.get_prompt_template( + app_mode=app_mode, + provider=model_config.provider, + model=model_config.model, + pre_prompt=pre_prompt, + has_context=context is not None, + query_in_prompt=query is not None, + with_memory_prompt=histories is not None + ) + + variables = {k: inputs[k] for k in prompt_template_config['custom_variable_keys'] if k in inputs} + + for v in prompt_template_config['special_variable_keys']: + # support #context#, #query# and #histories# + if v == '#context#': + variables['#context#'] = context if context else '' + elif v == '#query#': + variables['#query#'] = query if query else '' + elif v == '#histories#': + variables['#histories#'] = histories if histories else '' + + prompt_template = prompt_template_config['prompt_template'] + prompt = prompt_template.format(variables) + + return prompt, prompt_template_config['prompt_rules'] + + def get_prompt_template(self, app_mode: AppMode, + provider: str, + model: str, + pre_prompt: str, + has_context: bool, + query_in_prompt: bool, + with_memory_prompt: bool = False) -> dict: + 
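+        # A sketch of the returned dict, consistent with the unit tests added
+        # later in this series (example values are illustrative, not verbatim):
+        #   config = SimplePromptTransform().get_prompt_template(
+        #       app_mode=AppMode.CHAT, provider='openai', model='gpt-4',
+        #       pre_prompt='you are {{name}}', has_context=True,
+        #       query_in_prompt=True, with_memory_prompt=False)
+        #   config['custom_variable_keys']   # -> ['name']
+        #   config['special_variable_keys']  # -> ['#context#', '#query#']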
prompt_rules = self._get_prompt_rule( + app_mode=app_mode, + provider=provider, + model=model + ) + + custom_variable_keys = [] + special_variable_keys = [] + + prompt = '' + for order in prompt_rules['system_prompt_orders']: + if order == 'context_prompt' and has_context: + prompt += prompt_rules['context_prompt'] + special_variable_keys.append('#context#') + elif order == 'pre_prompt' and pre_prompt: + prompt += pre_prompt + '\n' + pre_prompt_template = PromptTemplateParser(template=pre_prompt) + custom_variable_keys = pre_prompt_template.variable_keys + elif order == 'histories_prompt' and with_memory_prompt: + prompt += prompt_rules['histories_prompt'] + special_variable_keys.append('#histories#') + + if query_in_prompt: + prompt += prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{#query#}}' + special_variable_keys.append('#query#') + + return { + "prompt_template": PromptTemplateParser(template=prompt), + "custom_variable_keys": custom_variable_keys, + "special_variable_keys": special_variable_keys, + "prompt_rules": prompt_rules + } + + def _get_chat_model_prompt_messages(self, pre_prompt: str, + inputs: dict, + query: str, + context: Optional[str], + files: list[FileObj], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigEntity) \ + -> Tuple[list[PromptMessage], Optional[list[str]]]: + prompt_messages = [] + + # get prompt + prompt, _ = self.get_prompt_str_and_rules( + app_mode=AppMode.CHAT, + model_config=model_config, + pre_prompt=pre_prompt, + inputs=inputs, + query=query, + context=context + ) + + if prompt: + prompt_messages.append(SystemPromptMessage(content=prompt)) + + self._append_chat_histories( + memory=memory, + prompt_messages=prompt_messages, + model_config=model_config + ) + + prompt_messages.append(self.get_last_user_message(query, files)) + + return prompt_messages, None + + def _get_completion_model_prompt_messages(self, pre_prompt: str, + inputs: dict, + query: str, + context: Optional[str], + files: list[FileObj], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigEntity) \ + -> Tuple[list[PromptMessage], Optional[list[str]]]: + # get prompt + prompt, prompt_rules = self.get_prompt_str_and_rules( + app_mode=AppMode.CHAT, + model_config=model_config, + pre_prompt=pre_prompt, + inputs=inputs, + query=query, + context=context + ) + + if memory: + tmp_human_message = UserPromptMessage( + content=prompt + ) + + rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) + histories = self._get_history_messages_from_memory( + memory=memory, + max_token_limit=rest_tokens, + ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', + human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' + ) + + # get prompt + prompt, prompt_rules = self.get_prompt_str_and_rules( + app_mode=AppMode.CHAT, + model_config=model_config, + pre_prompt=pre_prompt, + inputs=inputs, + query=query, + context=context, + histories=histories + ) + + stops = prompt_rules.get('stops') + if stops is not None and len(stops) == 0: + stops = None + + return [self.get_last_user_message(prompt, files)], stops + + def get_last_user_message(self, prompt: str, files: list[FileObj]) -> UserPromptMessage: + if files: + prompt_message_contents = [TextPromptMessageContent(data=prompt)] + for file in files: + prompt_message_contents.append(file.prompt_message_content) + + prompt_message = UserPromptMessage(content=prompt_message_contents) + else: + prompt_message = 
UserPromptMessage(content=prompt) + + return prompt_message + + def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict: + """ + Get simple prompt rule. + :param app_mode: app mode + :param provider: model provider + :param model: model name + :return: + """ + prompt_file_name = self._prompt_file_name( + app_mode=app_mode, + provider=provider, + model=model + ) + + # Check if the prompt file is already loaded + if prompt_file_name in prompt_file_contents: + return prompt_file_contents[prompt_file_name] + + # Get the absolute path of the subdirectory + prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'generate_prompts') + json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json') + + # Open the JSON file and read its content + with open(json_file_path, encoding='utf-8') as json_file: + content = json.load(json_file) + + # Store the content of the prompt file + prompt_file_contents[prompt_file_name] = content + + def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str: + # baichuan + is_baichuan = False + if provider == 'baichuan': + is_baichuan = True + else: + baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"] + if provider in baichuan_supported_providers and 'baichuan' in model.lower(): + is_baichuan = True + + if is_baichuan: + if app_mode == AppMode.WORKFLOW: + return 'baichuan_completion' + else: + return 'baichuan_chat' + + # common + if app_mode == AppMode.WORKFLOW: + return 'common_completion' + else: + return 'common_chat' diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index d9cd6c03bb..c778084475 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -2,7 +2,6 @@ from flask_restful import fields from libs.helper import TimestampField - annotation_fields = { "id": fields.String, "question": fields.String, diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 9dc92ea43b..decdc0567f 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -5,7 +5,6 @@ from flask_restful import fields from fields.member_fields import simple_account_fields from libs.helper import TimestampField - workflow_fields = { 'id': fields.String, 'graph': fields.Raw(attribute=lambda x: json.loads(x.graph) if hasattr(x, 'graph') else None), diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index c2fad83aaf..7d18f4f675 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -2,9 +2,17 @@ import json from typing import Optional from core.application_manager import ApplicationManager -from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, FileUploadEntity, \ - ExternalDataVariableEntity, DatasetEntity, VariableEntity +from core.entities.application_entities import ( + DatasetEntity, + ExternalDataVariableEntity, + FileUploadEntity, + ModelConfigEntity, + PromptTemplateEntity, + VariableEntity, DatasetRetrieveConfigEntity, +) +from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils import helper +from core.prompt.simple_prompt_transform import SimplePromptTransform from core.workflow.entities.NodeEntities import NodeType from core.workflow.nodes.end.entities import EndNodeOutputType from extensions.ext_database import db @@ -32,6 +40,9 @@ class WorkflowConverter: :param account: Account instance :return: workflow instance 
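        The resulting graph chains Start -> (optional) Knowledge Retrieval ->
        LLM -> End nodes, mirroring the original app model config.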
""" + # get new app mode + new_app_mode = self._get_new_app_mode(app_model) + # get original app config app_model_config = app_model.app_model_config @@ -75,14 +86,17 @@ class WorkflowConverter: # convert to knowledge retrieval node if app_model_config.dataset: knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node( - dataset=app_model_config.dataset, - show_retrieve_source=app_model_config.show_retrieve_source + new_app_mode=new_app_mode, + dataset_config=app_model_config.dataset ) - graph = self._append_node(graph, knowledge_retrieval_node) + if knowledge_retrieval_node: + graph = self._append_node(graph, knowledge_retrieval_node) # convert to llm node llm_node = self._convert_to_llm_node( + new_app_mode=new_app_mode, + graph=graph, model_config=app_model_config.model_config, prompt_template=app_model_config.prompt_template, file_upload=app_model_config.file_upload @@ -95,14 +109,11 @@ class WorkflowConverter: graph = self._append_node(graph, end_node) - # get new app mode - app_mode = self._get_new_app_mode(app_model) - # create workflow record workflow = Workflow( tenant_id=app_model.tenant_id, app_id=app_model.id, - type=WorkflowType.from_app_mode(app_mode).value, + type=WorkflowType.from_app_mode(new_app_mode).value, version='draft', graph=json.dumps(graph), created_by=account.id @@ -124,7 +135,7 @@ class WorkflowConverter: new_app_model_config.completion_prompt_config = '' new_app_model_config.dataset_configs = '' new_app_model_config.chatbot_app_engine = ChatbotAppEngine.WORKFLOW.value \ - if app_mode == AppMode.CHAT else ChatbotAppEngine.NORMAL.value + if new_app_mode == AppMode.CHAT else ChatbotAppEngine.NORMAL.value new_app_model_config.workflow_id = workflow.id db.session.add(new_app_model_config) @@ -157,18 +168,22 @@ class WorkflowConverter: # TODO: implement pass - def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, dataset: DatasetEntity) -> dict: + def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, dataset_config: DatasetEntity) \ + -> Optional[dict]: """ Convert datasets to Knowledge Retrieval Node :param new_app_mode: new app mode - :param dataset: dataset + :param dataset_config: dataset :return: """ - # TODO: implement + retrieve_config = dataset_config.retrieve_config if new_app_mode == AppMode.CHAT: query_variable_selector = ["start", "sys.query"] + elif retrieve_config.query_variable: + # fetch query variable + query_variable_selector = ["start", retrieve_config.query_variable] else: - pass + return None return { "id": "knowledge-retrieval", @@ -176,20 +191,139 @@ class WorkflowConverter: "data": { "title": "KNOWLEDGE RETRIEVAL", "type": NodeType.KNOWLEDGE_RETRIEVAL.value, + "query_variable_selector": query_variable_selector, + "dataset_ids": dataset_config.dataset_ids, + "retrieval_mode": retrieve_config.retrieve_strategy.value, + "multiple_retrieval_config": { + "top_k": retrieve_config.top_k, + "score_threshold": retrieve_config.score_threshold, + "reranking_model": retrieve_config.reranking_model + } + if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE + else None, } } - def _convert_to_llm_node(self, model_config: ModelConfigEntity, + def _convert_to_llm_node(self, new_app_mode: AppMode, + graph: dict, + model_config: ModelConfigEntity, prompt_template: PromptTemplateEntity, file_upload: Optional[FileUploadEntity] = None) -> dict: """ Convert to LLM Node + :param new_app_mode: new app mode + :param graph: graph :param model_config: model config :param prompt_template: 
prompt template :param file_upload: file upload config (optional) """ - # TODO: implement - pass + # fetch start and knowledge retrieval node + start_node = next(filter(lambda n: n['data']['type'] == NodeType.START.value, graph['nodes'])) + knowledge_retrieval_node = next(filter( + lambda n: n['data']['type'] == NodeType.KNOWLEDGE_RETRIEVAL.value, + graph['nodes'] + ), None) + + role_prefix = None + + # Chat Model + if model_config.mode == LLMMode.CHAT.value: + if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + # get prompt template + prompt_transform = SimplePromptTransform() + prompt_template_config = prompt_transform.get_prompt_template( + app_mode=AppMode.WORKFLOW, + provider=model_config.provider, + model=model_config.model, + pre_prompt=prompt_template.simple_prompt_template, + has_context=knowledge_retrieval_node is not None, + query_in_prompt=False + ) + prompts = [ + { + "role": 'user', + "text": prompt_template_config['prompt_template'].template + } + ] + else: + advanced_chat_prompt_template = prompt_template.advanced_chat_prompt_template + prompts = [helper.dump_model(m) for m in advanced_chat_prompt_template.messages] \ + if advanced_chat_prompt_template else [] + # Completion Model + else: + if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + # get prompt template + prompt_transform = SimplePromptTransform() + prompt_template_config = prompt_transform.get_prompt_template( + app_mode=AppMode.WORKFLOW, + provider=model_config.provider, + model=model_config.model, + pre_prompt=prompt_template.simple_prompt_template, + has_context=knowledge_retrieval_node is not None, + query_in_prompt=False + ) + prompts = { + "text": prompt_template_config['prompt_template'].template + } + + prompt_rules = prompt_template_config['prompt_rules'] + role_prefix = { + "user": prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', + "assistant": prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' + } + else: + advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template + prompts = { + "text": advanced_completion_prompt_template.prompt, + } if advanced_completion_prompt_template else {"text": ""} + + if advanced_completion_prompt_template.role_prefix: + role_prefix = { + "user": advanced_completion_prompt_template.role_prefix.user, + "assistant": advanced_completion_prompt_template.role_prefix.assistant + } + + memory = None + if new_app_mode == AppMode.CHAT: + memory = { + "role_prefix": role_prefix, + "window": { + "enabled": False + } + } + + return { + "id": "llm", + "position": None, + "data": { + "title": "LLM", + "type": NodeType.LLM.value, + "model": { + "provider": model_config.provider, + "name": model_config.model, + "mode": model_config.mode, + "completion_params": model_config.parameters.update({"stop": model_config.stop}) + }, + "variables": [{ + "variable": v['variable'], + "value_selector": ["start", v['variable']] + } for v in start_node['data']['variables']], + "prompts": prompts, + "memory": memory, + "context": { + "enabled": knowledge_retrieval_node is not None, + "variable_selector": ["knowledge-retrieval", "result"] + if knowledge_retrieval_node is not None else None + }, + "vision": { + "enabled": file_upload is not None, + "variable_selector": ["start", "sys.files"] if file_upload is not None else None, + "configs": { + "detail": file_upload.image_config['detail'] + } if file_upload is not None else None + } + } + } def 
_convert_to_end_node(self, app_model: App) -> dict: """ From 8642354a2aaf7ad6b758048c9d50ef5ee5efb195 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 22 Feb 2024 03:20:39 +0800 Subject: [PATCH 047/200] lint --- api/core/prompt/advanced_prompt_transform.py | 17 +++++++++++++---- api/core/prompt/simple_prompt_transform.py | 8 ++++---- api/services/workflow/workflow_converter.py | 3 ++- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 9ca3ef0375..397f708f1f 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,11 +1,20 @@ from typing import Optional -from core.entities.application_entities import PromptTemplateEntity, ModelConfigEntity, \ - AdvancedCompletionPromptTemplateEntity +from core.entities.application_entities import ( + AdvancedCompletionPromptTemplateEntity, + ModelConfigEntity, + PromptTemplateEntity, +) from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, UserPromptMessage, \ - SystemPromptMessage, AssistantPromptMessage, TextPromptMessageContent +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageRole, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) from core.prompt.prompt_template import PromptTemplateParser from core.prompt.prompt_transform import PromptTransform from core.prompt.simple_prompt_transform import ModelMode diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index a898c37c4a..6e158bef39 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -1,7 +1,7 @@ import enum import json import os -from typing import Optional, Tuple +from typing import Optional from core.entities.application_entities import ( ModelConfigEntity, @@ -85,7 +85,7 @@ class SimplePromptTransform(PromptTransform): query: Optional[str] = None, context: Optional[str] = None, histories: Optional[str] = None, - ) -> Tuple[str, dict]: + ) -> tuple[str, dict]: # get prompt template prompt_template_config = self.get_prompt_template( app_mode=app_mode, @@ -160,7 +160,7 @@ class SimplePromptTransform(PromptTransform): files: list[FileObj], memory: Optional[TokenBufferMemory], model_config: ModelConfigEntity) \ - -> Tuple[list[PromptMessage], Optional[list[str]]]: + -> tuple[list[PromptMessage], Optional[list[str]]]: prompt_messages = [] # get prompt @@ -193,7 +193,7 @@ class SimplePromptTransform(PromptTransform): files: list[FileObj], memory: Optional[TokenBufferMemory], model_config: ModelConfigEntity) \ - -> Tuple[list[PromptMessage], Optional[list[str]]]: + -> tuple[list[PromptMessage], Optional[list[str]]]: # get prompt prompt, prompt_rules = self.get_prompt_str_and_rules( app_mode=AppMode.CHAT, diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 7d18f4f675..647713b404 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -4,11 +4,12 @@ from typing import Optional from core.application_manager import ApplicationManager from core.entities.application_entities import ( DatasetEntity, + DatasetRetrieveConfigEntity, ExternalDataVariableEntity, FileUploadEntity, ModelConfigEntity, 
PromptTemplateEntity, - VariableEntity, DatasetRetrieveConfigEntity, + VariableEntity, ) from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils import helper From 3b234febf5a04565b92590ec077b079fd20a4578 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 22 Feb 2024 15:15:42 +0800 Subject: [PATCH 048/200] fix bugs and add unit tests --- api/core/prompt/simple_prompt_transform.py | 35 +-- api/models/workflow.py | 4 +- api/tests/unit_tests/.gitignore | 1 + api/tests/unit_tests/__init__.py | 0 api/tests/unit_tests/conftest.py | 7 + api/tests/unit_tests/core/__init__.py | 0 api/tests/unit_tests/core/prompt/__init__.py | 0 .../core/prompt/test_prompt_transform.py | 47 ++++ .../prompt/test_simple_prompt_transform.py | 216 ++++++++++++++++++ 9 files changed, 292 insertions(+), 18 deletions(-) create mode 100644 api/tests/unit_tests/.gitignore create mode 100644 api/tests/unit_tests/__init__.py create mode 100644 api/tests/unit_tests/conftest.py create mode 100644 api/tests/unit_tests/core/__init__.py create mode 100644 api/tests/unit_tests/core/prompt/__init__.py create mode 100644 api/tests/unit_tests/core/prompt/test_prompt_transform.py create mode 100644 api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 6e158bef39..a51cc86e8b 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -45,6 +45,7 @@ class SimplePromptTransform(PromptTransform): """ Simple Prompt Transform for Chatbot App Basic Mode. """ + def get_prompt(self, prompt_template_entity: PromptTemplateEntity, inputs: dict, @@ -154,12 +155,12 @@ class SimplePromptTransform(PromptTransform): } def _get_chat_model_prompt_messages(self, pre_prompt: str, - inputs: dict, - query: str, - context: Optional[str], - files: list[FileObj], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) \ + inputs: dict, + query: str, + context: Optional[str], + files: list[FileObj], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigEntity) \ -> tuple[list[PromptMessage], Optional[list[str]]]: prompt_messages = [] @@ -169,7 +170,7 @@ class SimplePromptTransform(PromptTransform): model_config=model_config, pre_prompt=pre_prompt, inputs=inputs, - query=query, + query=None, context=context ) @@ -187,12 +188,12 @@ class SimplePromptTransform(PromptTransform): return prompt_messages, None def _get_completion_model_prompt_messages(self, pre_prompt: str, - inputs: dict, - query: str, - context: Optional[str], - files: list[FileObj], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) \ + inputs: dict, + query: str, + context: Optional[str], + files: list[FileObj], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigEntity) \ -> tuple[list[PromptMessage], Optional[list[str]]]: # get prompt prompt, prompt_rules = self.get_prompt_str_and_rules( @@ -259,7 +260,7 @@ class SimplePromptTransform(PromptTransform): provider=provider, model=model ) - + # Check if the prompt file is already loaded if prompt_file_name in prompt_file_contents: return prompt_file_contents[prompt_file_name] @@ -267,14 +268,16 @@ class SimplePromptTransform(PromptTransform): # Get the absolute path of the subdirectory prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'generate_prompts') json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json') - + # Open the JSON file and read its 
content with open(json_file_path, encoding='utf-8') as json_file: content = json.load(json_file) - + # Store the content of the prompt file prompt_file_contents[prompt_file_name] = content + return content + def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str: # baichuan is_baichuan = False diff --git a/api/models/workflow.py b/api/models/workflow.py index ed26e98896..95805e7871 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -5,7 +5,6 @@ from sqlalchemy.dialects.postgresql import UUID from extensions.ext_database import db from models.account import Account -from models.model import AppMode class WorkflowType(Enum): @@ -29,13 +28,14 @@ class WorkflowType(Enum): raise ValueError(f'invalid workflow type value {value}') @classmethod - def from_app_mode(cls, app_mode: Union[str, AppMode]) -> 'WorkflowType': + def from_app_mode(cls, app_mode: Union[str, 'AppMode']) -> 'WorkflowType': """ Get workflow type from app mode. :param app_mode: app mode :return: workflow type """ + from models.model import AppMode app_mode = app_mode if isinstance(app_mode, AppMode) else AppMode.value_of(app_mode) return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT diff --git a/api/tests/unit_tests/.gitignore b/api/tests/unit_tests/.gitignore new file mode 100644 index 0000000000..426667562b --- /dev/null +++ b/api/tests/unit_tests/.gitignore @@ -0,0 +1 @@ +.env.test \ No newline at end of file diff --git a/api/tests/unit_tests/__init__.py b/api/tests/unit_tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py new file mode 100644 index 0000000000..afc9802cf1 --- /dev/null +++ b/api/tests/unit_tests/conftest.py @@ -0,0 +1,7 @@ +import os + +# Getting the absolute path of the current file's directory +ABS_PATH = os.path.dirname(os.path.abspath(__file__)) + +# Getting the absolute path of the project's root directory +PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir)) diff --git a/api/tests/unit_tests/core/__init__.py b/api/tests/unit_tests/core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/prompt/__init__.py b/api/tests/unit_tests/core/prompt/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py new file mode 100644 index 0000000000..8a260b0507 --- /dev/null +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -0,0 +1,47 @@ +from unittest.mock import MagicMock + +from core.entities.application_entities import ModelConfigEntity +from core.entities.provider_configuration import ProviderModelBundle +from core.model_runtime.entities.message_entities import UserPromptMessage +from core.model_runtime.entities.model_entities import ModelPropertyKey, AIModelEntity, ParameterRule +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.prompt.prompt_transform import PromptTransform + + +def test__calculate_rest_token(): + model_schema_mock = MagicMock(spec=AIModelEntity) + parameter_rule_mock = MagicMock(spec=ParameterRule) + parameter_rule_mock.name = 'max_tokens' + model_schema_mock.parameter_rules = [ + parameter_rule_mock + ] + model_schema_mock.model_properties = { + ModelPropertyKey.CONTEXT_SIZE: 62 + } + + large_language_model_mock = MagicMock(spec=LargeLanguageModel) + 
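+    # get_num_tokens is stubbed to return 6, so with the mocked context size
+    # (62) and max_tokens (50) below, the remainder is 62 - 50 - 6 = 6.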
large_language_model_mock.get_num_tokens.return_value = 6 + + provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle) + provider_model_bundle_mock.model_type_instance = large_language_model_mock + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.model = 'gpt-4' + model_config_mock.credentials = {} + model_config_mock.parameters = { + 'max_tokens': 50 + } + model_config_mock.model_schema = model_schema_mock + model_config_mock.provider_model_bundle = provider_model_bundle_mock + + prompt_transform = PromptTransform() + + prompt_messages = [UserPromptMessage(content="Hello, how are you?")] + rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock) + + # Validate based on the mock configuration and expected logic + expected_rest_tokens = (model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE] + - model_config_mock.parameters['max_tokens'] + - large_language_model_mock.get_num_tokens.return_value) + assert rest_tokens == expected_rest_tokens + assert rest_tokens == 6 diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py new file mode 100644 index 0000000000..cb6ad02541 --- /dev/null +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -0,0 +1,216 @@ +from unittest.mock import MagicMock + +from core.entities.application_entities import ModelConfigEntity +from core.prompt.simple_prompt_transform import SimplePromptTransform +from models.model import AppMode + + +def test_get_common_chat_app_prompt_template_with_pcqm(): + prompt_transform = SimplePromptTransform() + pre_prompt = "You are a helpful assistant." + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider="openai", + model="gpt-4", + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=True, + ) + prompt_rules = prompt_template['prompt_rules'] + assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] + + pre_prompt + '\n' + + prompt_rules['histories_prompt'] + + prompt_rules['query_prompt']) + assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#'] + + +def test_get_baichuan_chat_app_prompt_template_with_pcqm(): + prompt_transform = SimplePromptTransform() + pre_prompt = "You are a helpful assistant." + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider="baichuan", + model="Baichuan2-53B", + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=True, + ) + prompt_rules = prompt_template['prompt_rules'] + assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] + + pre_prompt + '\n' + + prompt_rules['histories_prompt'] + + prompt_rules['query_prompt']) + assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#'] + + +def test_get_common_completion_app_prompt_template_with_pcq(): + prompt_transform = SimplePromptTransform() + pre_prompt = "You are a helpful assistant." 
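+    # The test-name suffixes appear to encode which parts are enabled:
+    # p = pre_prompt, c = context, q = query, m = memory histories.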
+ prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.WORKFLOW, + provider="openai", + model="gpt-4", + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=False, + ) + prompt_rules = prompt_template['prompt_rules'] + assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] + + pre_prompt + '\n' + + prompt_rules['query_prompt']) + assert prompt_template['special_variable_keys'] == ['#context#', '#query#'] + + +def test_get_baichuan_completion_app_prompt_template_with_pcq(): + prompt_transform = SimplePromptTransform() + pre_prompt = "You are a helpful assistant." + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.WORKFLOW, + provider="baichuan", + model="Baichuan2-53B", + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=False, + ) + print(prompt_template['prompt_template'].template) + prompt_rules = prompt_template['prompt_rules'] + assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] + + pre_prompt + '\n' + + prompt_rules['query_prompt']) + assert prompt_template['special_variable_keys'] == ['#context#', '#query#'] + + +def test_get_common_chat_app_prompt_template_with_q(): + prompt_transform = SimplePromptTransform() + pre_prompt = "" + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider="openai", + model="gpt-4", + pre_prompt=pre_prompt, + has_context=False, + query_in_prompt=True, + with_memory_prompt=False, + ) + prompt_rules = prompt_template['prompt_rules'] + assert prompt_template['prompt_template'].template == prompt_rules['query_prompt'] + assert prompt_template['special_variable_keys'] == ['#query#'] + + +def test_get_common_chat_app_prompt_template_with_cq(): + prompt_transform = SimplePromptTransform() + pre_prompt = "" + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider="openai", + model="gpt-4", + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=False, + ) + prompt_rules = prompt_template['prompt_rules'] + assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] + + prompt_rules['query_prompt']) + assert prompt_template['special_variable_keys'] == ['#context#', '#query#'] + + +def test_get_common_chat_app_prompt_template_with_p(): + prompt_transform = SimplePromptTransform() + pre_prompt = "you are {{name}}" + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider="openai", + model="gpt-4", + pre_prompt=pre_prompt, + has_context=False, + query_in_prompt=False, + with_memory_prompt=False, + ) + assert prompt_template['prompt_template'].template == pre_prompt + '\n' + assert prompt_template['custom_variable_keys'] == ['name'] + assert prompt_template['special_variable_keys'] == [] + + +def test__get_chat_model_prompt_messages(): + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = 'gpt-4' + + prompt_transform = SimplePromptTransform() + pre_prompt = "You are a helpful assistant {{name}}." + inputs = { + "name": "John" + } + context = "yes or no." + query = "How are you?" 
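+    # For chat models the assembled template becomes a single system message,
+    # and the query is appended as a separate user message, hence the
+    # two-message assertion at the end of this test.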
+ prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages( + pre_prompt=pre_prompt, + inputs=inputs, + query=query, + files=[], + context=context, + memory=None, + model_config=model_config_mock + ) + + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider=model_config_mock.provider, + model=model_config_mock.model, + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=False, + with_memory_prompt=False, + ) + + full_inputs = {**inputs, '#context#': context} + real_system_prompt = prompt_template['prompt_template'].format(full_inputs) + + assert len(prompt_messages) == 2 + assert prompt_messages[0].content == real_system_prompt + assert prompt_messages[1].content == query + + +def test__get_completion_model_prompt_messages(): + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = 'gpt-3.5-turbo-instruct' + + prompt_transform = SimplePromptTransform() + pre_prompt = "You are a helpful assistant {{name}}." + inputs = { + "name": "John" + } + context = "yes or no." + query = "How are you?" + prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages( + pre_prompt=pre_prompt, + inputs=inputs, + query=query, + files=[], + context=context, + memory=None, + model_config=model_config_mock + ) + + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider=model_config_mock.provider, + model=model_config_mock.model, + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=False, + ) + + full_inputs = {**inputs, '#context#': context, '#query#': query} + real_prompt = prompt_template['prompt_template'].format(full_inputs) + + assert len(prompt_messages) == 1 + assert stops == prompt_template['prompt_rules'].get('stops') + assert prompt_messages[0].content == real_prompt From 6aecf42b6e5d05659ba589f62dc1d6645ba85de9 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 22 Feb 2024 22:32:33 +0800 Subject: [PATCH 049/200] fix prompt transform bugs --- api/core/prompt/advanced_prompt_transform.py | 26 ++- api/core/prompt/prompt_transform.py | 4 +- api/core/prompt/simple_prompt_transform.py | 2 +- .../prompt/test_advanced_prompt_transform.py | 193 ++++++++++++++++++ .../prompt/test_simple_prompt_transform.py | 46 ++++- 5 files changed, 251 insertions(+), 20 deletions(-) create mode 100644 api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 397f708f1f..0ed9ec352c 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -20,7 +20,7 @@ from core.prompt.prompt_transform import PromptTransform from core.prompt.simple_prompt_transform import ModelMode -class AdvancePromptTransform(PromptTransform): +class AdvancedPromptTransform(PromptTransform): """ Advanced Prompt Transform for Workflow LLM Node. 
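    Resolves the special variables #context#, #query# and #histories# in
    user-authored prompt templates before building the final prompt messages.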
""" @@ -74,10 +74,10 @@ class AdvancePromptTransform(PromptTransform): prompt_template = PromptTemplateParser(template=raw_prompt) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - self._set_context_variable(context, prompt_template, prompt_inputs) + prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) role_prefix = prompt_template_entity.advanced_completion_prompt_template.role_prefix - self._set_histories_variable( + prompt_inputs = self._set_histories_variable( memory=memory, raw_prompt=raw_prompt, role_prefix=role_prefix, @@ -104,7 +104,7 @@ class AdvancePromptTransform(PromptTransform): def _get_chat_model_prompt_messages(self, prompt_template_entity: PromptTemplateEntity, inputs: dict, - query: str, + query: Optional[str], files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], @@ -122,7 +122,7 @@ class AdvancePromptTransform(PromptTransform): prompt_template = PromptTemplateParser(template=raw_prompt) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - self._set_context_variable(context, prompt_template, prompt_inputs) + prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) prompt = prompt_template.format( prompt_inputs @@ -136,7 +136,7 @@ class AdvancePromptTransform(PromptTransform): prompt_messages.append(AssistantPromptMessage(content=prompt)) if memory: - self._append_chat_histories(memory, prompt_messages, model_config) + prompt_messages = self._append_chat_histories(memory, prompt_messages, model_config) if files: prompt_message_contents = [TextPromptMessageContent(data=query)] @@ -157,7 +157,7 @@ class AdvancePromptTransform(PromptTransform): last_message.content = prompt_message_contents else: - prompt_message_contents = [TextPromptMessageContent(data=query)] + prompt_message_contents = [TextPromptMessageContent(data='')] # not for query for file in files: prompt_message_contents.append(file.prompt_message_content) @@ -165,26 +165,30 @@ class AdvancePromptTransform(PromptTransform): return prompt_messages - def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None: + def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: if '#context#' in prompt_template.variable_keys: if context: prompt_inputs['#context#'] = context else: prompt_inputs['#context#'] = '' - def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None: + return prompt_inputs + + def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: if '#query#' in prompt_template.variable_keys: if query: prompt_inputs['#query#'] = query else: prompt_inputs['#query#'] = '' + return prompt_inputs + def _set_histories_variable(self, memory: TokenBufferMemory, raw_prompt: str, role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity, prompt_template: PromptTemplateParser, prompt_inputs: dict, - model_config: ModelConfigEntity) -> None: + model_config: ModelConfigEntity) -> dict: if '#histories#' in prompt_template.variable_keys: if memory: inputs = {'#histories#': '', **prompt_inputs} @@ -205,3 +209,5 @@ class AdvancePromptTransform(PromptTransform): prompt_inputs['#histories#'] = histories else: prompt_inputs['#histories#'] = '' + + return prompt_inputs diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py 
index c0f70ae0bb..9596976b6e 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -10,12 +10,14 @@ from core.model_runtime.model_providers.__base.large_language_model import Large class PromptTransform: def _append_chat_histories(self, memory: TokenBufferMemory, prompt_messages: list[PromptMessage], - model_config: ModelConfigEntity) -> None: + model_config: ModelConfigEntity) -> list[PromptMessage]: if memory: rest_tokens = self._calculate_rest_token(prompt_messages, model_config) histories = self._get_history_messages_list_from_memory(memory, rest_tokens) prompt_messages.extend(histories) + return prompt_messages + def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: ModelConfigEntity) -> int: rest_tokens = 2000 diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index a51cc86e8b..2f98fbcae8 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -177,7 +177,7 @@ class SimplePromptTransform(PromptTransform): if prompt: prompt_messages.append(SystemPromptMessage(content=prompt)) - self._append_chat_histories( + prompt_messages = self._append_chat_histories( memory=memory, prompt_messages=prompt_messages, model_config=model_config diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py new file mode 100644 index 0000000000..65a160a8e5 --- /dev/null +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -0,0 +1,193 @@ +from unittest.mock import MagicMock + +import pytest + +from core.entities.application_entities import PromptTemplateEntity, AdvancedCompletionPromptTemplateEntity, \ + ModelConfigEntity, AdvancedChatPromptTemplateEntity, AdvancedChatMessageEntity +from core.file.file_obj import FileObj, FileType, FileTransferMethod +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.prompt_template import PromptTemplateParser +from models.model import Conversation + + +def test__get_completion_model_prompt_messages(): + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = 'gpt-3.5-turbo-instruct' + + prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}." + prompt_template_entity = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( + prompt=prompt_template, + role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity( + user="Human", + assistant="Assistant" + ) + ) + ) + inputs = { + "name": "John" + } + files = [] + context = "I am superman." 
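The change above to `_set_context_variable` (and its siblings `_set_query_variable`, `_set_histories_variable`, and `_append_chat_histories`) swaps an implicit contract (mutate the argument, return `None`) for an explicit one where the helper hands the collection back and the caller rebinds it. A hedged before/after sketch of the two styles, with function names invented purely for illustration:

```python
def set_context_in_place(prompt_inputs: dict, context: str) -> None:
    # old style: the caller only sees the change through the side effect
    prompt_inputs['#context#'] = context or ''


def set_context_returning(prompt_inputs: dict, context: str) -> dict:
    # new style: same mutation, but the data flow is visible at the call site
    prompt_inputs['#context#'] = context or ''
    return prompt_inputs


inputs: dict = {}
inputs = set_context_returning(inputs, 'some context')
assert inputs == {'#context#': 'some context'}

# rebinding against the in-place variant would silently yield None:
assert set_context_in_place({}, 'some context') is None
```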
+ + memory = TokenBufferMemory( + conversation=Conversation(), + model_instance=model_config_mock + ) + + history_prompt_messages = [ + UserPromptMessage(content="Hi"), + AssistantPromptMessage(content="Hello") + ] + memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages) + + prompt_transform = AdvancedPromptTransform() + prompt_transform._calculate_rest_token = MagicMock(return_value=2000) + prompt_messages = prompt_transform._get_completion_model_prompt_messages( + prompt_template_entity=prompt_template_entity, + inputs=inputs, + files=files, + context=context, + memory=memory, + model_config=model_config_mock + ) + + assert len(prompt_messages) == 1 + assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format({ + "#context#": context, + "#histories#": "\n".join([f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: " + f"{prompt.content}" for prompt in history_prompt_messages]), + **inputs, + }) + + +def test__get_chat_model_prompt_messages(get_chat_model_args): + model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args + + files = [] + query = "Hi2." + + memory = TokenBufferMemory( + conversation=Conversation(), + model_instance=model_config_mock + ) + + history_prompt_messages = [ + UserPromptMessage(content="Hi1."), + AssistantPromptMessage(content="Hello1!") + ] + memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages) + + prompt_transform = AdvancedPromptTransform() + prompt_transform._calculate_rest_token = MagicMock(return_value=2000) + prompt_messages = prompt_transform._get_chat_model_prompt_messages( + prompt_template_entity=prompt_template_entity, + inputs=inputs, + query=query, + files=files, + context=context, + memory=memory, + model_config=model_config_mock + ) + + assert len(prompt_messages) == 6 + assert prompt_messages[0].role == PromptMessageRole.SYSTEM + assert prompt_messages[0].content == PromptTemplateParser( + template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text + ).format({**inputs, "#context#": context}) + assert prompt_messages[5].content == query + + +def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args): + model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args + + files = [] + + prompt_transform = AdvancedPromptTransform() + prompt_transform._calculate_rest_token = MagicMock(return_value=2000) + prompt_messages = prompt_transform._get_chat_model_prompt_messages( + prompt_template_entity=prompt_template_entity, + inputs=inputs, + query=None, + files=files, + context=context, + memory=None, + model_config=model_config_mock + ) + + assert len(prompt_messages) == 3 + assert prompt_messages[0].role == PromptMessageRole.SYSTEM + assert prompt_messages[0].content == PromptTemplateParser( + template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text + ).format({**inputs, "#context#": context}) + + +def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args): + model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args + + files = [ + FileObj( + id="file1", + tenant_id="tenant1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + url="https://example.com/image1.jpg", + file_config={ + "image": { + "detail": "high", + } + } + ) + ] + + prompt_transform = AdvancedPromptTransform() + prompt_transform._calculate_rest_token = MagicMock(return_value=2000) + prompt_messages = 
prompt_transform._get_chat_model_prompt_messages( + prompt_template_entity=prompt_template_entity, + inputs=inputs, + query=None, + files=files, + context=context, + memory=None, + model_config=model_config_mock + ) + + assert len(prompt_messages) == 4 + assert prompt_messages[0].role == PromptMessageRole.SYSTEM + assert prompt_messages[0].content == PromptTemplateParser( + template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text + ).format({**inputs, "#context#": context}) + assert isinstance(prompt_messages[3].content, list) + assert len(prompt_messages[3].content) == 2 + assert prompt_messages[3].content[1].data == files[0].url + + +@pytest.fixture +def get_chat_model_args(): + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = 'gpt-4' + + prompt_template_entity = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity( + messages=[ + AdvancedChatMessageEntity(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", + role=PromptMessageRole.SYSTEM), + AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), + AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), + ] + ) + ) + + inputs = { + "name": "John" + } + + context = "I am superman." + + return model_config_mock, prompt_template_entity, inputs, context diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index cb6ad02541..c174983e38 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -1,8 +1,10 @@ from unittest.mock import MagicMock from core.entities.application_entities import ModelConfigEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage from core.prompt.simple_prompt_transform import SimplePromptTransform -from models.model import AppMode +from models.model import AppMode, Conversation def test_get_common_chat_app_prompt_template_with_pcqm(): @@ -141,7 +143,16 @@ def test__get_chat_model_prompt_messages(): model_config_mock.provider = 'openai' model_config_mock.model = 'gpt-4' + memory_mock = MagicMock(spec=TokenBufferMemory) + history_prompt_messages = [ + UserPromptMessage(content="Hi"), + AssistantPromptMessage(content="Hello") + ] + memory_mock.get_history_prompt_messages.return_value = history_prompt_messages + prompt_transform = SimplePromptTransform() + prompt_transform._calculate_rest_token = MagicMock(return_value=2000) + pre_prompt = "You are a helpful assistant {{name}}." 
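The setup above shows the mocking idiom these tests rely on throughout: construct the real object, then overwrite a single expensive method with `MagicMock(return_value=...)` so the test never touches a database or a model provider. A compact standalone illustration of the pattern; the `Greeter` class is a stand-in invented for this sketch:

```python
from unittest.mock import MagicMock


class Greeter:
    def fetch_name(self) -> str:
        raise RuntimeError('would call an external service')

    def greet(self) -> str:
        return f'Hello, {self.fetch_name()}!'


greeter = Greeter()
greeter.fetch_name = MagicMock(return_value='John')  # stub just this one method

assert greeter.greet() == 'Hello, John!'
greeter.fetch_name.assert_called_once()
```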
 inputs = {
     "name": "John"
@@ -154,7 +165,7 @@ def test__get_chat_model_prompt_messages():
         query=query,
         files=[],
         context=context,
-        memory=None,
+        memory=memory_mock,
         model_config=model_config_mock
     )
 
@@ -171,9 +182,11 @@ def test__get_chat_model_prompt_messages():
     full_inputs = {**inputs, '#context#': context}
     real_system_prompt = prompt_template['prompt_template'].format(full_inputs)
 
-    assert len(prompt_messages) == 2
+    assert len(prompt_messages) == 4
     assert prompt_messages[0].content == real_system_prompt
-    assert prompt_messages[1].content == query
+    assert prompt_messages[1].content == history_prompt_messages[0].content
+    assert prompt_messages[2].content == history_prompt_messages[1].content
+    assert prompt_messages[3].content == query
 
 
 def test__get_completion_model_prompt_messages():
@@ -181,7 +194,19 @@ def test__get_completion_model_prompt_messages():
     model_config_mock.provider = 'openai'
     model_config_mock.model = 'gpt-3.5-turbo-instruct'
 
+    memory = TokenBufferMemory(
+        conversation=Conversation(),
+        model_instance=model_config_mock
+    )
+
+    history_prompt_messages = [
+        UserPromptMessage(content="Hi"),
+        AssistantPromptMessage(content="Hello")
+    ]
+    memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
+
     prompt_transform = SimplePromptTransform()
+    prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
     pre_prompt = "You are a helpful assistant {{name}}."
     inputs = {
         "name": "John"
@@ -194,7 +219,7 @@ def test__get_completion_model_prompt_messages():
         query=query,
         files=[],
         context=context,
-        memory=None,
+        memory=memory,
         model_config=model_config_mock
     )
 
@@ -205,12 +230,17 @@ def test__get_completion_model_prompt_messages():
         pre_prompt=pre_prompt,
         has_context=True,
         query_in_prompt=True,
-        with_memory_prompt=False,
+        with_memory_prompt=True,
     )
 
-    full_inputs = {**inputs, '#context#': context, '#query#': query}
+    prompt_rules = prompt_template['prompt_rules']
+    full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text(
+        max_token_limit=2000,
+        ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant',
+        human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
+    )}
     real_prompt = prompt_template['prompt_template'].format(full_inputs)
 
     assert len(prompt_messages) == 1
-    assert stops == prompt_template['prompt_rules'].get('stops')
+    assert stops == prompt_rules.get('stops')
     assert prompt_messages[0].content == real_prompt

From 45621ba4d7b8d95a6f2b78b27ad8ab3a04eb198a Mon Sep 17 00:00:00 2001
From: takatost
Date: Fri, 23 Feb 2024 14:58:03 +0800
Subject: [PATCH 050/200] add api extension to http request node convert

---
 api/core/features/external_data_fetch.py    |   7 -
 api/services/workflow/workflow_converter.py | 149 ++++++++++++++++++--
 2 files changed, 135 insertions(+), 21 deletions(-)

diff --git a/api/core/features/external_data_fetch.py b/api/core/features/external_data_fetch.py
index 7f23c8ed72..ef37f05528 100644
--- a/api/core/features/external_data_fetch.py
+++ b/api/core/features/external_data_fetch.py
@@ -1,5 +1,4 @@
 import concurrent
-import json
 import logging
 from concurrent.futures import ThreadPoolExecutor
 from typing import Optional
@@ -28,12 +27,6 @@ class ExternalDataFetchFeature:
         :param query: the query
         :return: the filled inputs
         """
-        # Group tools by type and config
-        grouped_tools = {}
-        for tool in external_data_tools:
-            tool_key = (tool.type, json.dumps(tool.config, sort_keys=True))
-
grouped_tools.setdefault(tool_key, []).append(tool) - results = {} with ThreadPoolExecutor() as executor: futures = {} diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 647713b404..1fb37afe01 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -11,6 +11,7 @@ from core.entities.application_entities import ( PromptTemplateEntity, VariableEntity, ) +from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils import helper from core.prompt.simple_prompt_transform import SimplePromptTransform @@ -18,6 +19,7 @@ from core.workflow.entities.NodeEntities import NodeType from core.workflow.nodes.end.entities import EndNodeOutputType from extensions.ext_database import db from models.account import Account +from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, ChatbotAppEngine from models.workflow import Workflow, WorkflowType @@ -49,7 +51,7 @@ class WorkflowConverter: # convert app model config application_manager = ApplicationManager() - application_manager.convert_from_app_model_config_dict( + app_orchestration_config_entity = application_manager.convert_from_app_model_config_dict( tenant_id=app_model.tenant_id, app_model_config_dict=app_model_config.to_dict() ) @@ -71,24 +73,27 @@ class WorkflowConverter: # convert to start node start_node = self._convert_to_start_node( - variables=app_model_config.variables + variables=app_orchestration_config_entity.variables ) graph['nodes'].append(start_node) # convert to http request node - if app_model_config.external_data_variables: - http_request_node = self._convert_to_http_request_node( - external_data_variables=app_model_config.external_data_variables + if app_orchestration_config_entity.external_data_variables: + http_request_nodes = self._convert_to_http_request_node( + app_model=app_model, + variables=app_orchestration_config_entity.variables, + external_data_variables=app_orchestration_config_entity.external_data_variables ) - graph = self._append_node(graph, http_request_node) + for http_request_node in http_request_nodes: + graph = self._append_node(graph, http_request_node) # convert to knowledge retrieval node - if app_model_config.dataset: + if app_orchestration_config_entity.dataset: knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node( new_app_mode=new_app_mode, - dataset_config=app_model_config.dataset + dataset_config=app_orchestration_config_entity.dataset ) if knowledge_retrieval_node: @@ -98,9 +103,9 @@ class WorkflowConverter: llm_node = self._convert_to_llm_node( new_app_mode=new_app_mode, graph=graph, - model_config=app_model_config.model_config, - prompt_template=app_model_config.prompt_template, - file_upload=app_model_config.file_upload + model_config=app_orchestration_config_entity.model_config, + prompt_template=app_orchestration_config_entity.prompt_template, + file_upload=app_orchestration_config_entity.file_upload ) graph = self._append_node(graph, llm_node) @@ -160,14 +165,130 @@ class WorkflowConverter: } } - def _convert_to_http_request_node(self, external_data_variables: list[ExternalDataVariableEntity]) -> dict: + def _convert_to_http_request_node(self, app_model: App, + variables: list[VariableEntity], + external_data_variables: list[ExternalDataVariableEntity]) -> list[dict]: """ Convert API Based Extension to HTTP Request Node + :param app_model: App instance 
+        :param variables: list of variables
         :param external_data_variables: list of external data variables
         :return:
         """
-        # TODO: implement
-        pass
+        index = 1
+        nodes = []
+        tenant_id = app_model.tenant_id
+        for external_data_variable in external_data_variables:
+            tool_type = external_data_variable.type
+            if tool_type != "api":
+                continue
+
+            tool_variable = external_data_variable.variable
+            tool_config = external_data_variable.config
+
+            # get params from config
+            api_based_extension_id = tool_config.get("api_based_extension_id")
+
+            # get api_based_extension
+            api_based_extension = db.session.query(APIBasedExtension).filter(
+                APIBasedExtension.tenant_id == tenant_id,
+                APIBasedExtension.id == api_based_extension_id
+            ).first()
+
+            if not api_based_extension:
+                raise ValueError("[External data tool] API query failed, variable: {}, "
+                                 "error: api_based_extension_id is invalid"
+                                 .format(tool_variable))
+
+            # decrypt api_key
+            api_key = encrypter.decrypt_token(
+                tenant_id=tenant_id,
+                token=api_based_extension.api_key
+            )
+
+            http_request_variables = []
+            inputs = {}
+            for v in variables:
+                http_request_variables.append({
+                    "variable": v.variable,
+                    "value_selector": ["start", v.variable]
+                })
+
+                inputs[v.variable] = '{{' + v.variable + '}}'
+
+            if app_model.mode == AppMode.CHAT.value:
+                http_request_variables.append({
+                    "variable": "_query",
+                    "value_selector": ["start", "sys.query"]
+                })
+
+            request_body = {
+                'point': APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value,
+                'params': {
+                    'app_id': app_model.id,
+                    'tool_variable': tool_variable,
+                    'inputs': inputs,
+                    'query': '{{_query}}' if app_model.mode == AppMode.CHAT.value else ''
+                }
+            }
+
+            request_body_json = json.dumps(request_body)
+            request_body_json = request_body_json.replace('\{\{', '{{').replace('\}\}', '}}')
+
+            http_request_node = {
+                "id": f"http-request-{index}",
+                "position": None,
+                "data": {
+                    "title": f"HTTP REQUEST {api_based_extension.name}",
+                    "type": NodeType.HTTP_REQUEST.value,
+                    "variables": http_request_variables,
+                    "method": "post",
+                    "url": api_based_extension.api_endpoint,
+                    "authorization": {
+                        "type": "api-key",
+                        "config": {
+                            "type": "bearer",
+                            "api_key": api_key
+                        }
+                    },
+                    "headers": "",
+                    "params": "",
+                    "body": {
+                        "type": "json",
+                        "data": request_body_json
+                    }
+                }
+            }
+            index += 1
+
+            nodes.append(http_request_node)
+
+            # append code node for response body parsing
+            code_node = {
+                "id": f"code-{index}",
+                "position": None,
+                "data": {
+                    "title": f"Parse {api_based_extension.name} response",
+                    "type": NodeType.CODE.value,
+                    "variables": [{
+                        "variable": "response_json",
+                        "value_selector": [http_request_node['id'], "body"]
+                    }],
+                    "code_language": "python3",
+                    "code": "import json\n\ndef main(response_json: str) -> dict:\n    response_body = json.loads("
+                            "response_json)\n    return {\n        \"result\": response_body[\"result\"]\n    }",
+                    "outputs": [
+                        {
+                            "variable": "result",
+                            "variable_type": "string"
+                        }
+                    ]
+                }
+            }
+
+            nodes.append(code_node)
+
+        return nodes
 
     def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, dataset_config: DatasetEntity) \
             -> Optional[dict]:

From 0806b3163ab45f8149acc493bb7b5c33095ebe65 Mon Sep 17 00:00:00 2001
From: takatost
Date: Fri, 23 Feb 2024 18:18:49 +0800
Subject: [PATCH 051/200] add http request node convert tests

---
 api/core/application_manager.py             |   8 +-
 api/core/entities/application_entities.py   |   1 +
 api/services/app_model_config_service.py    |   2 +-
 api/services/workflow/workflow_converter.py |  24 ++-
 api/tests/unit_tests/services/__init__.py   |   0
.../unit_tests/services/workflow/__init__.py | 0 .../workflow/test_workflow_converter.py | 184 ++++++++++++++++++ 7 files changed, 210 insertions(+), 9 deletions(-) create mode 100644 api/tests/unit_tests/services/__init__.py create mode 100644 api/tests/unit_tests/services/workflow/__init__.py create mode 100644 api/tests/unit_tests/services/workflow/test_workflow_converter.py diff --git a/api/core/application_manager.py b/api/core/application_manager.py index cf463be1df..77bb81b0da 100644 --- a/api/core/application_manager.py +++ b/api/core/application_manager.py @@ -400,10 +400,14 @@ class ApplicationManager: config=val['config'] ) ) - elif typ in [VariableEntity.Type.TEXT_INPUT.value, VariableEntity.Type.PARAGRAPH.value]: + elif typ in [ + VariableEntity.Type.TEXT_INPUT.value, + VariableEntity.Type.PARAGRAPH.value, + VariableEntity.Type.NUMBER.value, + ]: properties['variables'].append( VariableEntity( - type=VariableEntity.Type.TEXT_INPUT, + type=VariableEntity.Type.value_of(typ), variable=variable[typ].get('variable'), description=variable[typ].get('description'), label=variable[typ].get('label'), diff --git a/api/core/entities/application_entities.py b/api/core/entities/application_entities.py index f8f293d96a..667940f184 100644 --- a/api/core/entities/application_entities.py +++ b/api/core/entities/application_entities.py @@ -94,6 +94,7 @@ class VariableEntity(BaseModel): TEXT_INPUT = 'text-input' SELECT = 'select' PARAGRAPH = 'paragraph' + NUMBER = 'number' @classmethod def value_of(cls, value: str) -> 'VariableEntity.Type': diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 3ac11c645c..aa8cd73ea7 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -205,7 +205,7 @@ class AppModelConfigService: variables = [] for item in config["user_input_form"]: key = list(item.keys())[0] - if key not in ["text-input", "select", "paragraph", "external_data_tool"]: + if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]: raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'") form_item = item[key] diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 1fb37afe01..31df58a583 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -190,10 +190,10 @@ class WorkflowConverter: api_based_extension_id = tool_config.get("api_based_extension_id") # get api_based_extension - api_based_extension = db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() + api_based_extension = self._get_api_based_extension( + tenant_id=tenant_id, + api_based_extension_id=api_based_extension_id + ) if not api_based_extension: raise ValueError("[External data tool] API query failed, variable: {}, " @@ -259,7 +259,6 @@ class WorkflowConverter: } } } - index += 1 nodes.append(http_request_node) @@ -268,7 +267,7 @@ class WorkflowConverter: "id": f"code-{index}", "position": None, "data": { - "title": f"Parse {api_based_extension.name} response", + "title": f"Parse {api_based_extension.name} Response", "type": NodeType.CODE.value, "variables": [{ "variable": "response_json", @@ -287,6 +286,7 @@ class WorkflowConverter: } nodes.append(code_node) + index += 1 return nodes @@ -513,3 +513,15 @@ class WorkflowConverter: return AppMode.WORKFLOW else: return 
AppMode.value_of(app_model.mode) + + def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: + """ + Get API Based Extension + :param tenant_id: tenant id + :param api_based_extension_id: api based extension id + :return: + """ + return db.session.query(APIBasedExtension).filter( + APIBasedExtension.tenant_id == tenant_id, + APIBasedExtension.id == api_based_extension_id + ).first() diff --git a/api/tests/unit_tests/services/__init__.py b/api/tests/unit_tests/services/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/services/workflow/__init__.py b/api/tests/unit_tests/services/workflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py new file mode 100644 index 0000000000..69cf6afe45 --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -0,0 +1,184 @@ +# test for api/services/workflow/workflow_converter.py +import json +from unittest.mock import MagicMock + +import pytest + +from core.entities.application_entities import VariableEntity, ExternalDataVariableEntity +from core.helper import encrypter +from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint +from models.model import AppMode +from services.workflow.workflow_converter import WorkflowConverter + + +@pytest.fixture +def default_variables(): + return [ + VariableEntity( + variable="text-input", + label="text-input", + type=VariableEntity.Type.TEXT_INPUT + ), + VariableEntity( + variable="paragraph", + label="paragraph", + type=VariableEntity.Type.PARAGRAPH + ), + VariableEntity( + variable="select", + label="select", + type=VariableEntity.Type.SELECT + ) + ] + + +def test__convert_to_start_node(default_variables): + # act + result = WorkflowConverter()._convert_to_start_node(default_variables) + + # assert + assert result["data"]["variables"][0]["variable"] == "text-input" + assert result["data"]["variables"][1]["variable"] == "paragraph" + assert result["data"]["variables"][2]["variable"] == "select" + + +def test__convert_to_http_request_node(default_variables): + """ + Test convert to http request nodes + :return: + """ + app_model = MagicMock() + app_model.id = "app_id" + app_model.tenant_id = "tenant_id" + app_model.mode = AppMode.CHAT.value + + api_based_extension_id = "api_based_extension_id" + mock_api_based_extension = APIBasedExtension( + id=api_based_extension_id, + name="api-1", + api_key="encrypted_api_key", + api_endpoint="https://dify.ai", + ) + + workflow_converter = WorkflowConverter() + workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension) + + encrypter.decrypt_token = MagicMock(return_value="api_key") + + external_data_variables = [ + ExternalDataVariableEntity( + variable="external_variable", + type="api", + config={ + "api_based_extension_id": api_based_extension_id + } + ) + ] + + nodes = workflow_converter._convert_to_http_request_node( + app_model=app_model, + variables=default_variables, + external_data_variables=external_data_variables + ) + + assert len(nodes) == 2 + assert nodes[0]["data"]["type"] == "http-request" + + http_request_node = nodes[0] + + assert len(http_request_node["data"]["variables"]) == 4 # appended _query variable + assert http_request_node["data"]["method"] == "post" + assert http_request_node["data"]["url"] == 
mock_api_based_extension.api_endpoint + assert http_request_node["data"]["authorization"]["type"] == "api-key" + assert http_request_node["data"]["authorization"]["config"] == { + "type": "bearer", + "api_key": "api_key" + } + assert http_request_node["data"]["body"]["type"] == "json" + + body_data = http_request_node["data"]["body"]["data"] + + assert body_data + + body_data_json = json.loads(body_data) + assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value + + body_params = body_data_json["params"] + assert body_params["app_id"] == app_model.id + assert body_params["tool_variable"] == external_data_variables[0].variable + assert len(body_params["inputs"]) == 3 + assert body_params["query"] == "{{_query}}" # for chatbot + + code_node = nodes[1] + assert code_node["data"]["type"] == "code" + + +def test__convert_to_http_request_node_for_workflow_app(default_variables): + """ + Test convert to http request nodes for workflow app + :return: + """ + app_model = MagicMock() + app_model.id = "app_id" + app_model.tenant_id = "tenant_id" + app_model.mode = AppMode.WORKFLOW.value + + api_based_extension_id = "api_based_extension_id" + mock_api_based_extension = APIBasedExtension( + id=api_based_extension_id, + name="api-1", + api_key="encrypted_api_key", + api_endpoint="https://dify.ai", + ) + + workflow_converter = WorkflowConverter() + workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension) + + encrypter.decrypt_token = MagicMock(return_value="api_key") + + external_data_variables = [ + ExternalDataVariableEntity( + variable="external_variable", + type="api", + config={ + "api_based_extension_id": api_based_extension_id + } + ) + ] + + nodes = workflow_converter._convert_to_http_request_node( + app_model=app_model, + variables=default_variables, + external_data_variables=external_data_variables + ) + + assert len(nodes) == 2 + assert nodes[0]["data"]["type"] == "http-request" + + http_request_node = nodes[0] + + assert len(http_request_node["data"]["variables"]) == 3 + assert http_request_node["data"]["method"] == "post" + assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint + assert http_request_node["data"]["authorization"]["type"] == "api-key" + assert http_request_node["data"]["authorization"]["config"] == { + "type": "bearer", + "api_key": "api_key" + } + assert http_request_node["data"]["body"]["type"] == "json" + + body_data = http_request_node["data"]["body"]["data"] + + assert body_data + + body_data_json = json.loads(body_data) + assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value + + body_params = body_data_json["params"] + assert body_params["app_id"] == app_model.id + assert body_params["tool_variable"] == external_data_variables[0].variable + assert len(body_params["inputs"]) == 3 + assert body_params["query"] == "" + + code_node = nodes[1] + assert code_node["data"]["type"] == "code" From f11bf9153deee59d773d30d073e272d22f0082bc Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 25 Feb 2024 13:47:43 +0800 Subject: [PATCH 052/200] add more tests --- .../workflow/test_workflow_converter.py | 266 +++++++++++++++++- 1 file changed, 263 insertions(+), 3 deletions(-) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index 69cf6afe45..ee9e5eb2fa 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ 
b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -4,8 +4,12 @@ from unittest.mock import MagicMock import pytest -from core.entities.application_entities import VariableEntity, ExternalDataVariableEntity +from core.entities.application_entities import VariableEntity, ExternalDataVariableEntity, DatasetEntity, \ + DatasetRetrieveConfigEntity, ModelConfigEntity, PromptTemplateEntity, AdvancedChatPromptTemplateEntity, \ + AdvancedChatMessageEntity, AdvancedCompletionPromptTemplateEntity from core.helper import encrypter +from core.model_runtime.entities.llm_entities import LLMMode +from core.model_runtime.entities.message_entities import PromptMessageRole from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import AppMode from services.workflow.workflow_converter import WorkflowConverter @@ -42,9 +46,9 @@ def test__convert_to_start_node(default_variables): assert result["data"]["variables"][2]["variable"] == "select" -def test__convert_to_http_request_node(default_variables): +def test__convert_to_http_request_node_for_chatbot(default_variables): """ - Test convert to http request nodes + Test convert to http request nodes for chatbot :return: """ app_model = MagicMock() @@ -182,3 +186,259 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables): code_node = nodes[1] assert code_node["data"]["type"] == "code" + + +def test__convert_to_knowledge_retrieval_node_for_chatbot(): + new_app_mode = AppMode.CHAT + + dataset_config = DatasetEntity( + dataset_ids=["dataset_id_1", "dataset_id_2"], + retrieve_config=DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + top_k=5, + score_threshold=0.8, + reranking_model={ + 'reranking_provider_name': 'cohere', + 'reranking_model_name': 'rerank-english-v2.0' + } + ) + ) + + node = WorkflowConverter()._convert_to_knowledge_retrieval_node( + new_app_mode=new_app_mode, + dataset_config=dataset_config + ) + + assert node["data"]["type"] == "knowledge-retrieval" + assert node["data"]["query_variable_selector"] == ["start", "sys.query"] + assert node["data"]["dataset_ids"] == dataset_config.dataset_ids + assert (node["data"]["retrieval_mode"] + == dataset_config.retrieve_config.retrieve_strategy.value) + assert node["data"]["multiple_retrieval_config"] == { + "top_k": dataset_config.retrieve_config.top_k, + "score_threshold": dataset_config.retrieve_config.score_threshold, + "reranking_model": dataset_config.retrieve_config.reranking_model + } + + +def test__convert_to_knowledge_retrieval_node_for_workflow_app(): + new_app_mode = AppMode.WORKFLOW + + dataset_config = DatasetEntity( + dataset_ids=["dataset_id_1", "dataset_id_2"], + retrieve_config=DatasetRetrieveConfigEntity( + query_variable="query", + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + top_k=5, + score_threshold=0.8, + reranking_model={ + 'reranking_provider_name': 'cohere', + 'reranking_model_name': 'rerank-english-v2.0' + } + ) + ) + + node = WorkflowConverter()._convert_to_knowledge_retrieval_node( + new_app_mode=new_app_mode, + dataset_config=dataset_config + ) + + assert node["data"]["type"] == "knowledge-retrieval" + assert node["data"]["query_variable_selector"] == ["start", dataset_config.retrieve_config.query_variable] + assert node["data"]["dataset_ids"] == dataset_config.dataset_ids + assert (node["data"]["retrieval_mode"] + == dataset_config.retrieve_config.retrieve_strategy.value) + assert 
node["data"]["multiple_retrieval_config"] == { + "top_k": dataset_config.retrieve_config.top_k, + "score_threshold": dataset_config.retrieve_config.score_threshold, + "reranking_model": dataset_config.retrieve_config.reranking_model + } + + +def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables): + new_app_mode = AppMode.CHAT + model = "gpt-4" + model_mode = LLMMode.CHAT + + workflow_converter = WorkflowConverter() + start_node = workflow_converter._convert_to_start_node(default_variables) + graph = { + "nodes": [ + start_node + ], + "edges": [] # no need + } + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = model + model_config_mock.mode = model_mode.value + model_config_mock.parameters = {} + model_config_mock.stop = [] + + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="You are a helpful assistant {{text-input}}, {{paragraph}}, {{select}}." + ) + + llm_node = workflow_converter._convert_to_llm_node( + new_app_mode=new_app_mode, + model_config=model_config_mock, + graph=graph, + prompt_template=prompt_template + ) + + assert llm_node["data"]["type"] == "llm" + assert llm_node["data"]["model"]['name'] == model + assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["variables"] == [{ + "variable": v.variable, + "value_selector": ["start", v.variable] + } for v in default_variables] + assert llm_node["data"]["prompts"][0]['text'] == prompt_template.simple_prompt_template + '\n' + assert llm_node["data"]['context']['enabled'] is False + + +def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variables): + new_app_mode = AppMode.CHAT + model = "gpt-3.5-turbo-instruct" + model_mode = LLMMode.COMPLETION + + workflow_converter = WorkflowConverter() + start_node = workflow_converter._convert_to_start_node(default_variables) + graph = { + "nodes": [ + start_node + ], + "edges": [] # no need + } + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = model + model_config_mock.mode = model_mode.value + model_config_mock.parameters = {} + model_config_mock.stop = [] + + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="You are a helpful assistant {{text-input}}, {{paragraph}}, {{select}}." 
+ ) + + llm_node = workflow_converter._convert_to_llm_node( + new_app_mode=new_app_mode, + model_config=model_config_mock, + graph=graph, + prompt_template=prompt_template + ) + + assert llm_node["data"]["type"] == "llm" + assert llm_node["data"]["model"]['name'] == model + assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["variables"] == [{ + "variable": v.variable, + "value_selector": ["start", v.variable] + } for v in default_variables] + assert llm_node["data"]["prompts"]['text'] == prompt_template.simple_prompt_template + '\n' + assert llm_node["data"]['context']['enabled'] is False + + +def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables): + new_app_mode = AppMode.CHAT + model = "gpt-4" + model_mode = LLMMode.CHAT + + workflow_converter = WorkflowConverter() + start_node = workflow_converter._convert_to_start_node(default_variables) + graph = { + "nodes": [ + start_node + ], + "edges": [] # no need + } + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = model + model_config_mock.mode = model_mode.value + model_config_mock.parameters = {} + model_config_mock.stop = [] + + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(messages=[ + AdvancedChatMessageEntity(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", + role=PromptMessageRole.SYSTEM), + AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), + AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), + ]) + ) + + llm_node = workflow_converter._convert_to_llm_node( + new_app_mode=new_app_mode, + model_config=model_config_mock, + graph=graph, + prompt_template=prompt_template + ) + + assert llm_node["data"]["type"] == "llm" + assert llm_node["data"]["model"]['name'] == model + assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["variables"] == [{ + "variable": v.variable, + "value_selector": ["start", v.variable] + } for v in default_variables] + assert isinstance(llm_node["data"]["prompts"], list) + assert len(llm_node["data"]["prompts"]) == len(prompt_template.advanced_chat_prompt_template.messages) + assert llm_node["data"]["prompts"][0]['text'] == prompt_template.advanced_chat_prompt_template.messages[0].text + + +def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables): + new_app_mode = AppMode.CHAT + model = "gpt-3.5-turbo-instruct" + model_mode = LLMMode.COMPLETION + + workflow_converter = WorkflowConverter() + start_node = workflow_converter._convert_to_start_node(default_variables) + graph = { + "nodes": [ + start_node + ], + "edges": [] # no need + } + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = model + model_config_mock.mode = model_mode.value + model_config_mock.parameters = {} + model_config_mock.stop = [] + + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( + prompt="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}\n\n" + "Human: hi\nAssistant: ", + role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity( + user="Human", + assistant="Assistant" + ) + ) + ) + + llm_node = workflow_converter._convert_to_llm_node( + 
new_app_mode=new_app_mode, + model_config=model_config_mock, + graph=graph, + prompt_template=prompt_template + ) + + assert llm_node["data"]["type"] == "llm" + assert llm_node["data"]["model"]['name'] == model + assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["variables"] == [{ + "variable": v.variable, + "value_selector": ["start", v.variable] + } for v in default_variables] + assert isinstance(llm_node["data"]["prompts"], dict) + assert llm_node["data"]["prompts"]['text'] == prompt_template.advanced_completion_prompt_template.prompt From 7458fde5a51f593376aedeafb78a8cac9cdb146d Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 25 Feb 2024 14:40:52 +0800 Subject: [PATCH 053/200] add agent app convert command --- api/commands.py | 55 ++++++++++++++++++++++++- api/controllers/console/app/workflow.py | 5 ++- api/services/workflow_service.py | 5 ++- 3 files changed, 61 insertions(+), 4 deletions(-) diff --git a/api/commands.py b/api/commands.py index 250039a365..9a023b1c48 100644 --- a/api/commands.py +++ b/api/commands.py @@ -15,7 +15,7 @@ from libs.rsa import generate_key_pair from models.account import Tenant from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment from models.dataset import Document as DatasetDocument -from models.model import Account, App, AppAnnotationSetting, MessageAnnotation +from models.model import Account, App, AppMode, AppModelConfig, AppAnnotationSetting, Conversation, MessageAnnotation from models.provider import Provider, ProviderModel @@ -370,8 +370,61 @@ def migrate_knowledge_vector_database(): fg='green')) +@click.command('convert-to-agent-apps', help='Convert Agent Assistant to Agent App.') +def convert_to_agent_apps(): + """ + Convert Agent Assistant to Agent App. + """ + click.echo(click.style('Start convert to agent apps.', fg='green')) + + proceeded_app_ids = [] + + while True: + # fetch first 1000 apps + sql_query = """SELECT a.id AS id FROM apps a +INNER JOIN app_model_configs am ON a.app_model_config_id=am.id +WHERE a.mode = 'chat' AND am.agent_mode is not null +and (am.agent_mode like '%"strategy": "function_call"%' or am.agent_mode like '%"strategy": "react"%') +and am.agent_mode like '{"enabled": true%' ORDER BY a.created_at DESC LIMIT 1000""" + with db.engine.begin() as conn: + rs = conn.execute(db.text(sql_query)) + + apps = [] + for i in rs: + app_id = str(i.id) + if app_id not in proceeded_app_ids: + proceeded_app_ids.append(app_id) + app = db.session.query(App).filter(App.id == app_id).first() + apps.append(app) + + if len(apps) == 0: + break + + for app in apps: + click.echo('Converting app: {}'.format(app.id)) + + try: + app.mode = AppMode.AGENT.value + db.session.commit() + + # update conversation mode to agent + db.session.query(Conversation).filter(Conversation.app_id == app.id).update( + {Conversation.mode: AppMode.AGENT.value} + ) + + db.session.commit() + click.echo(click.style('Converted app: {}'.format(app.id), fg='green')) + except Exception as e: + click.echo( + click.style('Convert app error: {} {}'.format(e.__class__.__name__, + str(e)), fg='red')) + + click.echo(click.style('Congratulations! 
Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green')) + + def register_commands(app): app.cli.add_command(reset_password) app.cli.add_command(reset_email) app.cli.add_command(reset_encrypt_key_pair) app.cli.add_command(vdb_migrate) + app.cli.add_command(convert_to_agent_apps) diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 1bb0ea34c1..7663e22580 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -77,7 +77,10 @@ class ConvertToWorkflowApi(Resource): """ # convert to workflow mode workflow_service = WorkflowService() - workflow = workflow_service.chatbot_convert_to_workflow(app_model=app_model) + workflow = workflow_service.chatbot_convert_to_workflow( + app_model=app_model, + account=current_user + ) # return workflow return workflow diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 6a967e86ff..0cb398225d 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -65,11 +65,12 @@ class WorkflowService: # return default block config return default_block_configs - def chatbot_convert_to_workflow(self, app_model: App) -> Workflow: + def chatbot_convert_to_workflow(self, app_model: App, account: Account) -> Workflow: """ basic mode of chatbot app to workflow :param app_model: App instance + :param account: Account instance :return: """ # check if chatbot app is in basic mode @@ -78,6 +79,6 @@ class WorkflowService: # convert to workflow mode workflow_converter = WorkflowConverter() - workflow = workflow_converter.convert_to_workflow(app_model=app_model) + workflow = workflow_converter.convert_to_workflow(app_model=app_model, account=account) return workflow From 2ba7ac8bc1f0b9d7cf49a2f5cd9d2f3bf19681a3 Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 25 Feb 2024 15:52:08 +0800 Subject: [PATCH 054/200] add expert mode of chatapp convert command --- api/commands.py | 72 ++++++++++++++++++- api/core/application_manager.py | 41 ++++++----- api/core/entities/application_entities.py | 2 +- api/services/workflow/workflow_converter.py | 23 +++--- api/services/workflow_service.py | 2 +- .../workflow/test_workflow_converter.py | 2 + 6 files changed, 114 insertions(+), 28 deletions(-) diff --git a/api/commands.py b/api/commands.py index 9a023b1c48..73d2150de2 100644 --- a/api/commands.py +++ b/api/commands.py @@ -1,5 +1,6 @@ import base64 import json +import logging import secrets import click @@ -12,11 +13,12 @@ from extensions.ext_database import db from libs.helper import email as email_validate from libs.password import hash_password, password_pattern, valid_password from libs.rsa import generate_key_pair -from models.account import Tenant +from models.account import Tenant, TenantAccountJoin from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment from models.dataset import Document as DatasetDocument from models.model import Account, App, AppMode, AppModelConfig, AppAnnotationSetting, Conversation, MessageAnnotation from models.provider import Provider, ProviderModel +from services.workflow.workflow_converter import WorkflowConverter @click.command('reset-password', help='Reset the account password.') @@ -422,9 +424,77 @@ and am.agent_mode like '{"enabled": true%' ORDER BY a.created_at DESC LIMIT 1000 click.echo(click.style('Congratulations! 
Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green'))
+@click.command('convert-to-workflow-chatbot-apps', help='Convert Basic Expert Assistant to Chatbot Workflow App.')
+def convert_to_workflow_chatbot_apps():
+    """
+    Convert Basic Expert Assistant to Chatbot Workflow App.
+    """
+    click.echo(click.style('Start convert to workflow chatbot apps.', fg='green'))
+
+    proceeded_app_ids = []
+    workflow_converter = WorkflowConverter()
+
+    while True:
+        # fetch first 1000 apps
+        sql_query = """SELECT a.id FROM apps a
+LEFT JOIN app_model_configs am ON a.app_model_config_id=am.id
+WHERE a.mode = 'chat' AND am.prompt_type='advanced' ORDER BY a.created_at DESC LIMIT 1000"""
+
+        with db.engine.begin() as conn:
+            rs = conn.execute(db.text(sql_query))
+
+        apps = []
+        for i in rs:
+            app_id = str(i.id)
+            print(app_id)
+            if app_id not in proceeded_app_ids:
+                proceeded_app_ids.append(app_id)
+                app = db.session.query(App).filter(App.id == app_id).first()
+                apps.append(app)
+
+        if len(apps) == 0:
+            break
+
+        for app in apps:
+            click.echo('Converting app: {}'.format(app.id))
+
+            try:
+                # get workspace of app
+                tenant = db.session.query(Tenant).filter(Tenant.id == app.tenant_id).first()
+                if not tenant:
+                    click.echo(click.style('Tenant not found: {}'.format(app.tenant_id), fg='red'))
+                    continue
+
+                # get workspace owner
+                tenant_account_join = db.session.query(TenantAccountJoin).filter(
+                    TenantAccountJoin.tenant_id == tenant.id,
+                    TenantAccountJoin.role == 'owner'
+                ).first()
+
+                if not tenant_account_join:
+                    click.echo(click.style('Tenant owner not found: {}'.format(tenant.id), fg='red'))
+                    continue
+
+                # convert to workflow
+                workflow_converter.convert_to_workflow(
+                    app_model=app,
+                    account_id=tenant_account_join.account_id
+                )
+
+                click.echo(click.style('Converted app: {}'.format(app.id), fg='green'))
+            except Exception as e:
+                logging.exception('Convert app error: {}'.format(app.id))
+                click.echo(
+                    click.style('Convert app error: {} {}'.format(e.__class__.__name__,
+                                                                  str(e)), fg='red'))
+
+    click.echo(click.style('Congratulations! Converted {} workflow chatbot apps.'.format(len(proceeded_app_ids)), fg='green'))
+
+
 def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
     app.cli.add_command(reset_encrypt_key_pair)
     app.cli.add_command(vdb_migrate)
     app.cli.add_command(convert_to_agent_apps)
+    app.cli.add_command(convert_to_workflow_chatbot_apps)

diff --git a/api/core/application_manager.py b/api/core/application_manager.py
index 77bb81b0da..ea0c85427d 100644
--- a/api/core/application_manager.py
+++ b/api/core/application_manager.py
@@ -235,12 +235,15 @@ class ApplicationManager:
             logger.exception(e)
             raise e
 
-    def convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
+    def convert_from_app_model_config_dict(self, tenant_id: str,
+                                           app_model_config_dict: dict,
+                                           skip_check: bool = False) \
             -> AppOrchestrationConfigEntity:
         """
         Convert app model config dict to entity.
:param tenant_id: tenant ID :param app_model_config_dict: app model config dict + :param skip_check: skip check :raises ProviderTokenNotInitError: provider token not init error :return: app orchestration config entity """ @@ -268,24 +271,28 @@ class ApplicationManager: ) if model_credentials is None: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + if not skip_check: + raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + else: + model_credentials = {} - # check model - provider_model = provider_model_bundle.configuration.get_provider_model( - model=copy_app_model_config_dict['model']['name'], - model_type=ModelType.LLM - ) + if not skip_check: + # check model + provider_model = provider_model_bundle.configuration.get_provider_model( + model=copy_app_model_config_dict['model']['name'], + model_type=ModelType.LLM + ) - if provider_model is None: - model_name = copy_app_model_config_dict['model']['name'] - raise ValueError(f"Model {model_name} not exist.") + if provider_model is None: + model_name = copy_app_model_config_dict['model']['name'] + raise ValueError(f"Model {model_name} not exist.") - if provider_model.status == ModelStatus.NO_CONFIGURE: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") - elif provider_model.status == ModelStatus.NO_PERMISSION: - raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") - elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: - raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") + if provider_model.status == ModelStatus.NO_CONFIGURE: + raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + elif provider_model.status == ModelStatus.NO_PERMISSION: + raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") + elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: + raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") # model config completion_params = copy_app_model_config_dict['model'].get('completion_params') @@ -309,7 +316,7 @@ class ApplicationManager: model_credentials ) - if not model_schema: + if not skip_check and not model_schema: raise ValueError(f"Model {model_name} not exist.") properties['model_config'] = ModelConfigEntity( diff --git a/api/core/entities/application_entities.py b/api/core/entities/application_entities.py index 667940f184..f5ea4d1eb0 100644 --- a/api/core/entities/application_entities.py +++ b/api/core/entities/application_entities.py @@ -15,7 +15,7 @@ class ModelConfigEntity(BaseModel): """ provider: str model: str - model_schema: AIModelEntity + model_schema: Optional[AIModelEntity] = None mode: str provider_model_bundle: ProviderModelBundle credentials: dict[str, Any] = {} diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 31df58a583..1d3cbe2e0e 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -13,12 +13,11 @@ from core.entities.application_entities import ( ) from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.utils import helper +from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.simple_prompt_transform import SimplePromptTransform from core.workflow.entities.NodeEntities import NodeType from core.workflow.nodes.end.entities import 
EndNodeOutputType from extensions.ext_database import db -from models.account import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, ChatbotAppEngine from models.workflow import Workflow, WorkflowType @@ -29,7 +28,7 @@ class WorkflowConverter: App Convert to Workflow Mode """ - def convert_to_workflow(self, app_model: App, account: Account) -> Workflow: + def convert_to_workflow(self, app_model: App, account_id: str) -> Workflow: """ Convert to workflow mode @@ -40,7 +39,7 @@ class WorkflowConverter: - completion app (for migration) :param app_model: App instance - :param account: Account instance + :param account_id: Account ID :return: workflow instance """ # get new app mode @@ -53,7 +52,8 @@ class WorkflowConverter: application_manager = ApplicationManager() app_orchestration_config_entity = application_manager.convert_from_app_model_config_dict( tenant_id=app_model.tenant_id, - app_model_config_dict=app_model_config.to_dict() + app_model_config_dict=app_model_config.to_dict(), + skip_check=True ) # init workflow graph @@ -122,7 +122,7 @@ class WorkflowConverter: type=WorkflowType.from_app_mode(new_app_mode).value, version='draft', graph=json.dumps(graph), - created_by=account.id + created_by=account_id ) db.session.add(workflow) @@ -130,6 +130,7 @@ class WorkflowConverter: # create new app model config record new_app_model_config = app_model_config.copy() + new_app_model_config.id = None new_app_model_config.external_data_tools = '' new_app_model_config.model = '' new_app_model_config.user_input_form = '' @@ -147,6 +148,9 @@ class WorkflowConverter: db.session.add(new_app_model_config) db.session.commit() + app_model.app_model_config_id = new_app_model_config.id + db.session.commit() + return workflow def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict: @@ -161,7 +165,7 @@ class WorkflowConverter: "data": { "title": "START", "type": NodeType.START.value, - "variables": [helper.dump_model(v) for v in variables] + "variables": [jsonable_encoder(v) for v in variables] } } @@ -369,7 +373,10 @@ class WorkflowConverter: ] else: advanced_chat_prompt_template = prompt_template.advanced_chat_prompt_template - prompts = [helper.dump_model(m) for m in advanced_chat_prompt_template.messages] \ + prompts = [{ + "role": m.role.value, + "text": m.text + } for m in advanced_chat_prompt_template.messages] \ if advanced_chat_prompt_template else [] # Completion Model else: diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 0cb398225d..bd88f3cbe2 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -79,6 +79,6 @@ class WorkflowService: # convert to workflow mode workflow_converter = WorkflowConverter() - workflow = workflow_converter.convert_to_workflow(app_model=app_model, account=account) + workflow = workflow_converter.convert_to_workflow(app_model=app_model, account_id=account.id) return workflow diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index ee9e5eb2fa..d4edc73410 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -41,6 +41,8 @@ def test__convert_to_start_node(default_variables): result = WorkflowConverter()._convert_to_start_node(default_variables) # assert + assert isinstance(result["data"]["variables"][0]["type"], str) + 
assert result["data"]["variables"][0]["type"] == "text-input" assert result["data"]["variables"][0]["variable"] == "text-input" assert result["data"]["variables"][1]["variable"] == "paragraph" assert result["data"]["variables"][2]["variable"] == "select" From 748aa22ee2e1deec036378d664bc7d6652886c4e Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 25 Feb 2024 21:02:28 +0800 Subject: [PATCH 055/200] add manual convert logic --- api/commands.py | 81 +----------- api/controllers/console/app/workflow.py | 8 +- .../versions/b289e2408ee2_add_workflow.py | 2 + api/models/model.py | 1 + api/models/workflow.py | 78 +++++++++++ api/services/workflow/workflow_converter.py | 123 +++++++++++++----- api/services/workflow_service.py | 29 +++-- 7 files changed, 198 insertions(+), 124 deletions(-) diff --git a/api/commands.py b/api/commands.py index 73d2150de2..e376d222c6 100644 --- a/api/commands.py +++ b/api/commands.py @@ -1,6 +1,5 @@ import base64 import json -import logging import secrets import click @@ -13,12 +12,11 @@ from extensions.ext_database import db from libs.helper import email as email_validate from libs.password import hash_password, password_pattern, valid_password from libs.rsa import generate_key_pair -from models.account import Tenant, TenantAccountJoin +from models.account import Tenant from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment from models.dataset import Document as DatasetDocument from models.model import Account, App, AppMode, AppModelConfig, AppAnnotationSetting, Conversation, MessageAnnotation from models.provider import Provider, ProviderModel -from services.workflow.workflow_converter import WorkflowConverter @click.command('reset-password', help='Reset the account password.') @@ -384,10 +382,11 @@ def convert_to_agent_apps(): while True: # fetch first 1000 apps sql_query = """SELECT a.id AS id FROM apps a -INNER JOIN app_model_configs am ON a.app_model_config_id=am.id -WHERE a.mode = 'chat' AND am.agent_mode is not null -and (am.agent_mode like '%"strategy": "function_call"%' or am.agent_mode like '%"strategy": "react"%') -and am.agent_mode like '{"enabled": true%' ORDER BY a.created_at DESC LIMIT 1000""" + INNER JOIN app_model_configs am ON a.app_model_config_id=am.id + WHERE a.mode = 'chat' AND am.agent_mode is not null + and (am.agent_mode like '%"strategy": "function_call"%' or am.agent_mode like '%"strategy": "react"%') + and am.agent_mode like '{"enabled": true%' ORDER BY a.created_at DESC LIMIT 1000""" + with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query)) @@ -424,77 +423,9 @@ and am.agent_mode like '{"enabled": true%' ORDER BY a.created_at DESC LIMIT 1000 click.echo(click.style('Congratulations! Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green')) -@click.command('convert-to-workflow-chatbot-apps', help='Convert Basic Export Assistant to Chatbot Workflow App.') -def convert_to_workflow_chatbot_apps(): - """ - Convert Basic Export Assistant to Chatbot Workflow App. 
- """ - click.echo(click.style('Start convert to workflow chatbot apps.', fg='green')) - - proceeded_app_ids = [] - workflow_converter = WorkflowConverter() - - while True: - # fetch first 1000 apps - sql_query = """SELECT a.id FROM apps a -LEFT JOIN app_model_configs am ON a.app_model_config_id=am.id -WHERE a.mode = 'chat' AND am.prompt_type='advanced' ORDER BY a.created_at DESC LIMIT 1000""" - - with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query)) - - apps = [] - for i in rs: - app_id = str(i.id) - print(app_id) - if app_id not in proceeded_app_ids: - proceeded_app_ids.append(app_id) - app = db.session.query(App).filter(App.id == app_id).first() - apps.append(app) - - if len(apps) == 0: - break - - for app in apps: - click.echo('Converting app: {}'.format(app.id)) - - try: - # get workspace of app - tenant = db.session.query(Tenant).filter(Tenant.id == app.tenant_id).first() - if not tenant: - click.echo(click.style('Tenant not found: {}'.format(app.tenant_id), fg='red')) - continue - - # get workspace owner - tenant_account_join = db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.tenant_id == tenant.id, - TenantAccountJoin.role == 'owner' - ).first() - - if not tenant_account_join: - click.echo(click.style('Tenant owner not found: {}'.format(tenant.id), fg='red')) - continue - - # convert to workflow - workflow_converter.convert_to_workflow( - app_model=app, - account_id=tenant_account_join.account_id - ) - - click.echo(click.style('Converted app: {}'.format(app.id), fg='green')) - except Exception as e: - logging.exception('Convert app error: {}'.format(app.id)) - click.echo( - click.style('Convert app error: {} {}'.format(e.__class__.__name__, - str(e)), fg='red')) - - click.echo(click.style('Congratulations! Converted {} workflow chatbot apps.'.format(len(proceeded_app_ids)), fg='green')) - - def register_commands(app): app.cli.add_command(reset_password) app.cli.add_command(reset_email) app.cli.add_command(reset_encrypt_key_pair) app.cli.add_command(vdb_migrate) app.cli.add_command(convert_to_agent_apps) - app.cli.add_command(convert_to_workflow_chatbot_apps) diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 7663e22580..dc1b7edcaf 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -69,15 +69,15 @@ class ConvertToWorkflowApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.CHAT) - @marshal_with(workflow_fields) + @get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION]) def post(self, app_model: App): """ - Convert basic mode of chatbot app to workflow + Convert basic mode of chatbot app(expert mode) to workflow mode + Convert Completion App to Workflow App """ # convert to workflow mode workflow_service = WorkflowService() - workflow = workflow_service.chatbot_convert_to_workflow( + workflow = workflow_service.convert_to_workflow( app_model=app_model, account=current_user ) diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py index e9cd2caf3a..9e04fef288 100644 --- a/api/migrations/versions/b289e2408ee2_add_workflow.py +++ b/api/migrations/versions/b289e2408ee2_add_workflow.py @@ -53,6 +53,7 @@ def upgrade(): sa.Column('elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False), sa.Column('execution_metadata', sa.Text(), nullable=True), sa.Column('created_at', sa.DateTime(), 
server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('created_by_role', sa.String(length=255), nullable=False), sa.Column('created_by', postgresql.UUID(), nullable=False), sa.Column('finished_at', sa.DateTime(), nullable=True), sa.PrimaryKeyConstraint('id', name='workflow_node_execution_pkey') @@ -80,6 +81,7 @@ def upgrade(): sa.Column('total_price', sa.Numeric(precision=10, scale=7), nullable=True), sa.Column('currency', sa.String(length=255), nullable=True), sa.Column('total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True), + sa.Column('created_by_role', sa.String(length=255), nullable=False), sa.Column('created_by', postgresql.UUID(), nullable=False), sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), sa.Column('finished_at', sa.DateTime(), nullable=True), diff --git a/api/models/model.py b/api/models/model.py index 58e29cd21c..1e66fd6c88 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -28,6 +28,7 @@ class DifySetup(db.Model): class AppMode(Enum): + COMPLETION = 'completion' WORKFLOW = 'workflow' CHAT = 'chat' AGENT = 'agent' diff --git a/api/models/workflow.py b/api/models/workflow.py index 95805e7871..251f33b0c0 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -7,6 +7,27 @@ from extensions.ext_database import db from models.account import Account +class CreatedByRole(Enum): + """ + Created By Role Enum + """ + ACCOUNT = 'account' + END_USER = 'end_user' + + @classmethod + def value_of(cls, value: str) -> 'CreatedByRole': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid created by role value {value}') + + class WorkflowType(Enum): """ Workflow Type Enum @@ -99,6 +120,49 @@ class Workflow(db.Model): return Account.query.get(self.updated_by) +class WorkflowRunTriggeredFrom(Enum): + """ + Workflow Run Triggered From Enum + """ + DEBUGGING = 'debugging' + APP_RUN = 'app-run' + + @classmethod + def value_of(cls, value: str) -> 'WorkflowRunTriggeredFrom': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid workflow run triggered from value {value}') + + +class WorkflowRunStatus(Enum): + """ + Workflow Run Status Enum + """ + RUNNING = 'running' + SUCCEEDED = 'succeeded' + FAILED = 'failed' + + @classmethod + def value_of(cls, value: str) -> 'WorkflowRunStatus': + """ + Get value of given mode. 
+ + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid workflow run status value {value}') + + class WorkflowRun(db.Model): """ Workflow Run @@ -128,6 +192,12 @@ class WorkflowRun(db.Model): - total_price (decimal) `optional` Total cost - currency (string) `optional` Currency, such as USD / RMB - total_steps (int) Total steps (redundant), default 0 + - created_by_role (string) Creator role + + - `account` Console account + + - `end_user` End user + - created_by (uuid) Runner ID - created_at (timestamp) Run time - finished_at (timestamp) End time @@ -157,6 +227,7 @@ class WorkflowRun(db.Model): total_price = db.Column(db.Numeric(10, 7)) currency = db.Column(db.String(255)) total_steps = db.Column(db.Integer, server_default=db.text('0')) + created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(UUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) finished_at = db.Column(db.DateTime) @@ -208,6 +279,12 @@ class WorkflowNodeExecution(db.Model): - currency (string) `optional` Currency, such as USD / RMB - created_at (timestamp) Run time + - created_by_role (string) Creator role + + - `account` Console account + + - `end_user` End user + - created_by (uuid) Runner ID - finished_at (timestamp) End time """ @@ -240,6 +317,7 @@ class WorkflowNodeExecution(db.Model): elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0')) execution_metadata = db.Column(db.Text) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(UUID, nullable=False) finished_at = db.Column(db.DateTime) diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 1d3cbe2e0e..bb300d1a77 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -17,9 +17,11 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.simple_prompt_transform import SimplePromptTransform from core.workflow.entities.NodeEntities import NodeType from core.workflow.nodes.end.entities import EndNodeOutputType +from events.app_event import app_was_created from extensions.ext_database import db +from models.account import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint -from models.model import App, AppMode, ChatbotAppEngine +from models.model import App, AppMode, ChatbotAppEngine, AppModelConfig, Site from models.workflow import Workflow, WorkflowType @@ -28,26 +30,99 @@ class WorkflowConverter: App Convert to Workflow Mode """ - def convert_to_workflow(self, app_model: App, account_id: str) -> Workflow: + def convert_to_workflow(self, app_model: App, account: Account) -> App: """ - Convert to workflow mode + Convert app to workflow - basic mode of chatbot app - - advanced mode of assistant app (for migration) + - advanced mode of assistant app - - completion app (for migration) + - completion app :param app_model: App instance + :param account: Account + :return: new App instance + """ + # get original app config + app_model_config = app_model.app_model_config + + # convert app model config + workflow = self.convert_app_model_config_to_workflow( + app_model=app_model, + app_model_config=app_model_config, + account_id=account.id + ) + + # create new app + new_app = 
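CreatedByRole, WorkflowRunTriggeredFrom and WorkflowRunStatus above all share the same value_of lookup; restated as a runnable sketch:

from enum import Enum

class WorkflowRunStatus(Enum):
    RUNNING = 'running'
    SUCCEEDED = 'succeeded'
    FAILED = 'failed'

    @classmethod
    def value_of(cls, value: str) -> 'WorkflowRunStatus':
        # linear scan over members; unknown values raise instead of
        # returning None so callers fail fast
        for status in cls:
            if status.value == value:
                return status
        raise ValueError(f'invalid workflow run status value {value}')

assert WorkflowRunStatus.value_of('succeeded') is WorkflowRunStatus.SUCCEEDED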
App() + new_app.tenant_id = app_model.tenant_id + new_app.name = app_model.name + '(workflow)' + new_app.mode = AppMode.CHAT.value \ + if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value + new_app.icon = app_model.icon + new_app.icon_background = app_model.icon_background + new_app.enable_site = app_model.enable_site + new_app.enable_api = app_model.enable_api + new_app.api_rpm = app_model.api_rpm + new_app.api_rph = app_model.api_rph + new_app.is_demo = False + new_app.is_public = app_model.is_public + db.session.add(new_app) + db.session.flush() + + # create new app model config record + new_app_model_config = app_model_config.copy() + new_app_model_config.id = None + new_app_model_config.app_id = new_app.id + new_app_model_config.external_data_tools = '' + new_app_model_config.model = '' + new_app_model_config.user_input_form = '' + new_app_model_config.dataset_query_variable = None + new_app_model_config.pre_prompt = None + new_app_model_config.agent_mode = '' + new_app_model_config.prompt_type = 'simple' + new_app_model_config.chat_prompt_config = '' + new_app_model_config.completion_prompt_config = '' + new_app_model_config.dataset_configs = '' + new_app_model_config.chatbot_app_engine = ChatbotAppEngine.WORKFLOW.value \ + if app_model.mode == AppMode.CHAT.value else ChatbotAppEngine.NORMAL.value + new_app_model_config.workflow_id = workflow.id + + db.session.add(new_app_model_config) + db.session.flush() + + new_app.app_model_config_id = new_app_model_config.id + db.session.commit() + + site = Site( + app_id=new_app.id, + title=new_app.name, + default_language=account.interface_language, + customize_token_strategy='not_allow', + code=Site.generate_code(16) + ) + + db.session.add(site) + db.session.commit() + + app_was_created.send(new_app) + + return new_app + + def convert_app_model_config_to_workflow(self, app_model: App, + app_model_config: AppModelConfig, + account_id: str) -> Workflow: + """ + Convert app model config to workflow mode + :param app_model: App instance + :param app_model_config: AppModelConfig instance :param account_id: Account ID - :return: workflow instance + :return: """ # get new app mode new_app_mode = self._get_new_app_mode(app_model) - # get original app config - app_model_config = app_model.app_model_config - # convert app model config application_manager = ApplicationManager() app_orchestration_config_entity = application_manager.convert_from_app_model_config_dict( @@ -122,33 +197,11 @@ class WorkflowConverter: type=WorkflowType.from_app_mode(new_app_mode).value, version='draft', graph=json.dumps(graph), - created_by=account_id + created_by=account_id, + created_at=app_model_config.created_at ) db.session.add(workflow) - db.session.flush() - - # create new app model config record - new_app_model_config = app_model_config.copy() - new_app_model_config.id = None - new_app_model_config.external_data_tools = '' - new_app_model_config.model = '' - new_app_model_config.user_input_form = '' - new_app_model_config.dataset_query_variable = None - new_app_model_config.pre_prompt = None - new_app_model_config.agent_mode = '' - new_app_model_config.prompt_type = 'simple' - new_app_model_config.chat_prompt_config = '' - new_app_model_config.completion_prompt_config = '' - new_app_model_config.dataset_configs = '' - new_app_model_config.chatbot_app_engine = ChatbotAppEngine.WORKFLOW.value \ - if new_app_mode == AppMode.CHAT else ChatbotAppEngine.NORMAL.value - new_app_model_config.workflow_id = workflow.id - - db.session.add(new_app_model_config) - 
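The app_was_created.send(new_app) call above goes through a blinker signal imported from events.app_event; a minimal sketch of that publish/subscribe flow — the signal name and handler here are placeholders, not the repository's actual subscribers:

from blinker import signal

app_was_created = signal('app-was-created')  # placeholder name

@app_was_created.connect
def on_app_created(sender, **kwargs):
    # real subscribers handle site setup, installation records, etc.
    print(f'post-create hooks for {sender}')

app_was_created.send('new-app')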
db.session.commit() - - app_model.app_model_config_id = new_app_model_config.id db.session.commit() return workflow @@ -469,7 +522,7 @@ class WorkflowConverter: "type": NodeType.END.value, } } - elif app_model.mode == "completion": + elif app_model.mode == AppMode.COMPLETION.value: # for original completion app return { "id": "end", @@ -516,7 +569,7 @@ class WorkflowConverter: :param app_model: App instance :return: AppMode """ - if app_model.mode == "completion": + if app_model.mode == AppMode.COMPLETION.value: return AppMode.WORKFLOW else: return AppMode.value_of(app_model.mode) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index bd88f3cbe2..2d9342ffc9 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -3,7 +3,7 @@ from datetime import datetime from extensions.ext_database import db from models.account import Account -from models.model import App, ChatbotAppEngine +from models.model import App, ChatbotAppEngine, AppMode from models.workflow import Workflow, WorkflowType from services.workflow.defaults import default_block_configs from services.workflow.workflow_converter import WorkflowConverter @@ -65,20 +65,29 @@ class WorkflowService: # return default block config return default_block_configs - def chatbot_convert_to_workflow(self, app_model: App, account: Account) -> Workflow: + def convert_to_workflow(self, app_model: App, account: Account) -> App: """ - basic mode of chatbot app to workflow + Basic mode of chatbot app(expert mode) to workflow + Completion App to Workflow App :param app_model: App instance :param account: Account instance :return: """ - # check if chatbot app is in basic mode - if app_model.app_model_config.chatbot_app_engine != ChatbotAppEngine.NORMAL: - raise ValueError('Chatbot app already in workflow mode') - - # convert to workflow mode + # chatbot convert to workflow mode workflow_converter = WorkflowConverter() - workflow = workflow_converter.convert_to_workflow(app_model=app_model, account_id=account.id) - return workflow + if app_model.mode == AppMode.CHAT.value: + # check if chatbot app is in basic mode + if app_model.app_model_config.chatbot_app_engine != ChatbotAppEngine.NORMAL: + raise ValueError('Chatbot app already in workflow mode') + elif app_model.mode != AppMode.COMPLETION.value: + raise ValueError(f'Current App mode: {app_model.mode} is not supported convert to workflow.') + + # convert to workflow + new_app = workflow_converter.convert_to_workflow( + app_model=app_model, + account=account + ) + + return new_app From 97c4733e7928b09b33e18c5f3f54856890c78c1f Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 25 Feb 2024 21:02:38 +0800 Subject: [PATCH 056/200] lint fix --- api/services/workflow/workflow_converter.py | 2 +- api/services/workflow_service.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index bb300d1a77..c6f0bed008 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -21,7 +21,7 @@ from events.app_event import app_was_created from extensions.ext_database import db from models.account import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint -from models.model import App, AppMode, ChatbotAppEngine, AppModelConfig, Site +from models.model import App, AppMode, AppModelConfig, ChatbotAppEngine, Site from models.workflow import Workflow, WorkflowType diff --git 
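The _get_new_app_mode change above boils down to one rule; a standalone restatement:

def get_new_app_mode(app_mode: str) -> str:
    # completion apps convert into workflow apps; every other mode is kept
    return 'workflow' if app_mode == 'completion' else app_mode

assert get_new_app_mode('completion') == 'workflow'
assert get_new_app_mode('chat') == 'chat'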
a/api/services/workflow_service.py b/api/services/workflow_service.py index 2d9342ffc9..4f7262b7d6 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -3,7 +3,7 @@ from datetime import datetime from extensions.ext_database import db from models.account import Account -from models.model import App, ChatbotAppEngine, AppMode +from models.model import App, AppMode, ChatbotAppEngine from models.workflow import Workflow, WorkflowType from services.workflow.defaults import default_block_configs from services.workflow.workflow_converter import WorkflowConverter From fce20e483cf4cc4eadd8f3386f4478ac5a50bbfd Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 25 Feb 2024 21:30:36 +0800 Subject: [PATCH 057/200] restore completion app --- api/controllers/console/app/app.py | 2 +- api/controllers/console/app/completion.py | 4 +- api/controllers/console/app/conversation.py | 4 +- api/controllers/console/app/statistic.py | 2 +- api/controllers/console/explore/message.py | 47 +++++++++++++++ api/controllers/web/message.py | 47 +++++++++++++++ api/core/app_runner/app_runner.py | 19 ++++-- api/core/prompt/prompt_transform.py | 7 +-- api/core/prompt/simple_prompt_transform.py | 38 +++++++----- api/services/app_model_config_service.py | 18 ++++++ api/services/completion_service.py | 60 ++++++++++++++++++- api/services/errors/__init__.py | 2 +- api/services/errors/app.py | 2 + .../prompt/test_simple_prompt_transform.py | 2 + 14 files changed, 224 insertions(+), 30 deletions(-) create mode 100644 api/services/errors/app.py diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index cf505bedb8..93dc1ca34a 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -80,7 +80,7 @@ class AppListApi(Resource): """Create app""" parser = reqparse.RequestParser() parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('mode', type=str, choices=[mode.value for mode in AppMode], location='json') + parser.add_argument('mode', type=str, choices=['chat', 'agent', 'workflow'], location='json') parser.add_argument('icon', type=str, location='json') parser.add_argument('icon_background', type=str, location='json') parser.add_argument('model_config', type=dict, location='json') diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 11fdba177d..e62475308f 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -37,7 +37,7 @@ class CompletionMessageApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.WORKFLOW) + @get_app_model(mode=AppMode.COMPLETION) def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('inputs', type=dict, required=True, location='json') @@ -90,7 +90,7 @@ class CompletionMessageStopApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.WORKFLOW) + @get_app_model(mode=AppMode.COMPLETION) def post(self, app_model, task_id): account = flask_login.current_user diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index daf9641121..b808d62eb0 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -29,7 +29,7 @@ class CompletionConversationApi(Resource): @setup_required @login_required @account_initialization_required - 
@get_app_model(mode=AppMode.WORKFLOW) + @get_app_model(mode=AppMode.COMPLETION) @marshal_with(conversation_pagination_fields) def get(self, app_model): parser = reqparse.RequestParser() @@ -102,7 +102,7 @@ class CompletionConversationDetailApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.WORKFLOW) + @get_app_model(mode=AppMode.COMPLETION) @marshal_with(conversation_message_detail_fields) def get(self, app_model, conversation_id): conversation_id = str(conversation_id) diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index ea4d597112..e3a5112200 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -330,7 +330,7 @@ class AverageResponseTimeStatistic(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.WORKFLOW) + @get_app_model(mode=AppMode.COMPLETION) def get(self, app_model): account = current_user diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index bef26b4d99..47af28425f 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -12,6 +12,7 @@ from werkzeug.exceptions import InternalServerError, NotFound import services from controllers.console import api from controllers.console.app.error import ( + AppMoreLikeThisDisabledError, CompletionRequestError, ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, @@ -23,10 +24,13 @@ from controllers.console.explore.error import ( NotCompletionAppError, ) from controllers.console.explore.wraps import InstalledAppResource +from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from fields.message_fields import message_infinite_scroll_pagination_fields from libs.helper import uuid_value +from services.completion_service import CompletionService +from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError from services.message_service import MessageService @@ -72,6 +76,48 @@ class MessageFeedbackApi(InstalledAppResource): return {'result': 'success'} +class MessageMoreLikeThisApi(InstalledAppResource): + def get(self, installed_app, message_id): + app_model = installed_app.app + if app_model.mode != 'completion': + raise NotCompletionAppError() + + message_id = str(message_id) + + parser = reqparse.RequestParser() + parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args') + args = parser.parse_args() + + streaming = args['response_mode'] == 'streaming' + + try: + response = CompletionService.generate_more_like_this( + app_model=app_model, + user=current_user, + message_id=message_id, + invoke_from=InvokeFrom.EXPLORE, + streaming=streaming + ) + return compact_response(response) + except MessageNotExistsError: + raise NotFound("Message Not Exists.") + except MoreLikeThisDisabledError: + raise AppMoreLikeThisDisabledError() + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise 
ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + except ValueError as e: + raise e + except Exception: + logging.exception("internal server error.") + raise InternalServerError() + + def compact_response(response: Union[dict, Generator]) -> Response: if isinstance(response, dict): return Response(response=json.dumps(response), status=200, mimetype='application/json') @@ -120,4 +166,5 @@ class MessageSuggestedQuestionApi(InstalledAppResource): api.add_resource(MessageListApi, '/installed-apps//messages', endpoint='installed_app_messages') api.add_resource(MessageFeedbackApi, '/installed-apps//messages//feedbacks', endpoint='installed_app_message_feedback') +api.add_resource(MessageMoreLikeThisApi, '/installed-apps//messages//more-like-this', endpoint='installed_app_more_like_this') api.add_resource(MessageSuggestedQuestionApi, '/installed-apps//messages//suggested-questions', endpoint='installed_app_suggested_question') diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 5120f49c5e..e03bdd63bb 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -11,6 +11,7 @@ from werkzeug.exceptions import InternalServerError, NotFound import services from controllers.web import api from controllers.web.error import ( + AppMoreLikeThisDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError, CompletionRequestError, NotChatAppError, @@ -20,11 +21,14 @@ from controllers.web.error import ( ProviderQuotaExceededError, ) from controllers.web.wraps import WebApiResource +from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import message_file_fields from fields.message_fields import agent_thought_fields from libs.helper import TimestampField, uuid_value +from services.completion_service import CompletionService +from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError from services.message_service import MessageService @@ -109,6 +113,48 @@ class MessageFeedbackApi(WebApiResource): return {'result': 'success'} +class MessageMoreLikeThisApi(WebApiResource): + def get(self, app_model, end_user, message_id): + if app_model.mode != 'completion': + raise NotCompletionAppError() + + message_id = str(message_id) + + parser = reqparse.RequestParser() + parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args') + args = parser.parse_args() + + streaming = args['response_mode'] == 'streaming' + + try: + response = CompletionService.generate_more_like_this( + app_model=app_model, + user=end_user, + message_id=message_id, + invoke_from=InvokeFrom.WEB_APP, + streaming=streaming + ) + + return compact_response(response) + except MessageNotExistsError: + raise NotFound("Message Not Exists.") + except MoreLikeThisDisabledError: + raise AppMoreLikeThisDisabledError() + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) 
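Once MessageMoreLikeThisApi is registered, a client picks blocking or streaming output via response_mode. A hypothetical call against the web API — base URL, token and message id are placeholders, and the route parameter is elided in the resource registration above:

import requests

message_id = '00000000-0000-0000-0000-000000000000'  # placeholder
resp = requests.get(
    f'https://dify.example.com/api/messages/{message_id}/more-like-this',
    params={'response_mode': 'blocking'},
    headers={'Authorization': 'Bearer app-your-token'},  # placeholder token
)
print(resp.json())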
+ except ValueError as e: + raise e + except Exception: + logging.exception("internal server error.") + raise InternalServerError() + + def compact_response(response: Union[dict, Generator]) -> Response: if isinstance(response, dict): return Response(response=json.dumps(response), status=200, mimetype='application/json') @@ -156,4 +202,5 @@ class MessageSuggestedQuestionApi(WebApiResource): api.add_resource(MessageListApi, '/messages') api.add_resource(MessageFeedbackApi, '/messages//feedbacks') +api.add_resource(MessageMoreLikeThisApi, '/messages//more-like-this') api.add_resource(MessageSuggestedQuestionApi, '/messages//suggested-questions') diff --git a/api/core/app_runner/app_runner.py b/api/core/app_runner/app_runner.py index c6f6268a7a..231530ef08 100644 --- a/api/core/app_runner/app_runner.py +++ b/api/core/app_runner/app_runner.py @@ -22,8 +22,9 @@ from core.model_runtime.entities.message_entities import AssistantPromptMessage, from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.simple_prompt_transform import SimplePromptTransform -from models.model import App, Message, MessageAnnotation +from models.model import App, Message, MessageAnnotation, AppMode class AppRunner: @@ -140,11 +141,11 @@ class AppRunner: :param memory: memory :return: """ - prompt_transform = SimplePromptTransform() - # get prompt without memory and context if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + prompt_transform = SimplePromptTransform() prompt_messages, stop = prompt_transform.get_prompt( + app_mode=AppMode.value_of(app_record.mode), prompt_template_entity=prompt_template_entity, inputs=inputs, query=query if query else '', @@ -154,7 +155,17 @@ class AppRunner: model_config=model_config ) else: - raise NotImplementedError("Advanced prompt is not supported yet.") + prompt_transform = AdvancedPromptTransform() + prompt_messages = prompt_transform.get_prompt( + prompt_template_entity=prompt_template_entity, + inputs=inputs, + query=query if query else '', + files=files, + context=context, + memory=memory, + model_config=model_config + ) + stop = model_config.stop return prompt_messages, stop diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 9596976b6e..9c554140b7 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -11,10 +11,9 @@ class PromptTransform: def _append_chat_histories(self, memory: TokenBufferMemory, prompt_messages: list[PromptMessage], model_config: ModelConfigEntity) -> list[PromptMessage]: - if memory: - rest_tokens = self._calculate_rest_token(prompt_messages, model_config) - histories = self._get_history_messages_list_from_memory(memory, rest_tokens) - prompt_messages.extend(histories) + rest_tokens = self._calculate_rest_token(prompt_messages, model_config) + histories = self._get_history_messages_list_from_memory(memory, rest_tokens) + prompt_messages.extend(histories) return prompt_messages diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 2f98fbcae8..a929416be4 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -47,6 +47,7 @@ class SimplePromptTransform(PromptTransform): """ def 
get_prompt(self, + app_mode: AppMode, prompt_template_entity: PromptTemplateEntity, inputs: dict, query: str, @@ -58,6 +59,7 @@ class SimplePromptTransform(PromptTransform): model_mode = ModelMode.value_of(model_config.mode) if model_mode == ModelMode.CHAT: prompt_messages, stops = self._get_chat_model_prompt_messages( + app_mode=app_mode, pre_prompt=prompt_template_entity.simple_prompt_template, inputs=inputs, query=query, @@ -68,6 +70,7 @@ class SimplePromptTransform(PromptTransform): ) else: prompt_messages, stops = self._get_completion_model_prompt_messages( + app_mode=app_mode, pre_prompt=prompt_template_entity.simple_prompt_template, inputs=inputs, query=query, @@ -154,7 +157,8 @@ class SimplePromptTransform(PromptTransform): "prompt_rules": prompt_rules } - def _get_chat_model_prompt_messages(self, pre_prompt: str, + def _get_chat_model_prompt_messages(self, app_mode: AppMode, + pre_prompt: str, inputs: dict, query: str, context: Optional[str], @@ -166,7 +170,7 @@ class SimplePromptTransform(PromptTransform): # get prompt prompt, _ = self.get_prompt_str_and_rules( - app_mode=AppMode.CHAT, + app_mode=app_mode, model_config=model_config, pre_prompt=pre_prompt, inputs=inputs, @@ -175,19 +179,25 @@ class SimplePromptTransform(PromptTransform): ) if prompt: - prompt_messages.append(SystemPromptMessage(content=prompt)) + if query: + prompt_messages.append(SystemPromptMessage(content=prompt)) + else: + prompt_messages.append(UserPromptMessage(content=prompt)) - prompt_messages = self._append_chat_histories( - memory=memory, - prompt_messages=prompt_messages, - model_config=model_config - ) + if memory: + prompt_messages = self._append_chat_histories( + memory=memory, + prompt_messages=prompt_messages, + model_config=model_config + ) - prompt_messages.append(self.get_last_user_message(query, files)) + if query: + prompt_messages.append(self.get_last_user_message(query, files)) return prompt_messages, None - def _get_completion_model_prompt_messages(self, pre_prompt: str, + def _get_completion_model_prompt_messages(self, app_mode: AppMode, + pre_prompt: str, inputs: dict, query: str, context: Optional[str], @@ -197,7 +207,7 @@ class SimplePromptTransform(PromptTransform): -> tuple[list[PromptMessage], Optional[list[str]]]: # get prompt prompt, prompt_rules = self.get_prompt_str_and_rules( - app_mode=AppMode.CHAT, + app_mode=app_mode, model_config=model_config, pre_prompt=pre_prompt, inputs=inputs, @@ -220,7 +230,7 @@ class SimplePromptTransform(PromptTransform): # get prompt prompt, prompt_rules = self.get_prompt_str_and_rules( - app_mode=AppMode.CHAT, + app_mode=app_mode, model_config=model_config, pre_prompt=pre_prompt, inputs=inputs, @@ -289,13 +299,13 @@ class SimplePromptTransform(PromptTransform): is_baichuan = True if is_baichuan: - if app_mode == AppMode.WORKFLOW: + if app_mode == AppMode.COMPLETION: return 'baichuan_completion' else: return 'baichuan_chat' # common - if app_mode == AppMode.WORKFLOW: + if app_mode == AppMode.COMPLETION: return 'common_completion' else: return 'common_chat' diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index aa8cd73ea7..34b6d62d51 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -316,6 +316,9 @@ class AppModelConfigService: if "tool_parameters" not in tool: raise ValueError("tool_parameters is required in agent_mode.tools") + # dataset_query_variable + cls.is_dataset_query_variable_valid(config, app_mode) + # advanced prompt validation 
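The AppMode.COMPLETION checks above feed template selection in SimplePromptTransform; the rule as a runnable function (the Baichuan detection condition is truncated in the hunk, so a simple name match is assumed here):

def prompt_file_name(app_mode: str, model_name: str) -> str:
    # assumed detection: any model whose name contains 'baichuan'
    if 'baichuan' in model_name.lower():
        return 'baichuan_completion' if app_mode == 'completion' else 'baichuan_chat'
    return 'common_completion' if app_mode == 'completion' else 'common_chat'

assert prompt_file_name('completion', 'gpt-3.5-turbo-instruct') == 'common_completion'
assert prompt_file_name('chat', 'Baichuan2-Turbo') == 'baichuan_chat'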
cls.is_advanced_prompt_valid(config, app_mode) @@ -441,6 +444,21 @@ class AppModelConfigService: config=config ) + @classmethod + def is_dataset_query_variable_valid(cls, config: dict, mode: str) -> None: + # Only check when mode is completion + if mode != 'completion': + return + + agent_mode = config.get("agent_mode", {}) + tools = agent_mode.get("tools", []) + dataset_exists = "dataset" in str(tools) + + dataset_query_variable = config.get("dataset_query_variable") + + if dataset_exists and not dataset_query_variable: + raise ValueError("Dataset query variable is required when dataset is exist") + @classmethod def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None: # prompt_type diff --git a/api/services/completion_service.py b/api/services/completion_service.py index 5599c60113..cbfbe9ef41 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -8,10 +8,12 @@ from core.application_manager import ApplicationManager from core.entities.application_entities import InvokeFrom from core.file.message_file_parser import MessageFileParser from extensions.ext_database import db -from models.model import Account, App, AppModelConfig, Conversation, EndUser +from models.model import Account, App, AppModelConfig, Conversation, EndUser, Message from services.app_model_config_service import AppModelConfigService +from services.errors.app import MoreLikeThisDisabledError from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError +from services.errors.message import MessageNotExistsError class CompletionService: @@ -155,6 +157,62 @@ class CompletionService: } ) + @classmethod + def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], + message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \ + -> Union[dict, Generator]: + if not user: + raise ValueError('user cannot be None') + + message = db.session.query(Message).filter( + Message.id == message_id, + Message.app_id == app_model.id, + Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), + Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Message.from_account_id == (user.id if isinstance(user, Account) else None), + ).first() + + if not message: + raise MessageNotExistsError() + + current_app_model_config = app_model.app_model_config + more_like_this = current_app_model_config.more_like_this_dict + + if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False: + raise MoreLikeThisDisabledError() + + app_model_config = message.app_model_config + model_dict = app_model_config.model_dict + completion_params = model_dict.get('completion_params') + completion_params['temperature'] = 0.9 + model_dict['completion_params'] = completion_params + app_model_config.model = json.dumps(model_dict) + + # parse files + message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + file_objs = message_file_parser.transform_message_files( + message.files, app_model_config + ) + + application_manager = ApplicationManager() + return application_manager.generate( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + app_model_config_id=app_model_config.id, + app_model_config_dict=app_model_config.to_dict(), + app_model_config_override=True, + user=user, + invoke_from=invoke_from, + inputs=message.inputs, + query=message.query, + files=file_objs, + 
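is_dataset_query_variable_valid above is small enough to restate directly as runnable code:

def is_dataset_query_variable_valid(config: dict, mode: str) -> None:
    # the check only applies to completion apps
    if mode != 'completion':
        return
    tools = config.get('agent_mode', {}).get('tools', [])
    dataset_exists = 'dataset' in str(tools)
    if dataset_exists and not config.get('dataset_query_variable'):
        raise ValueError('Dataset query variable is required when dataset is exist')

is_dataset_query_variable_valid({'agent_mode': {'tools': []}}, 'completion')  # no dataset, passes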
conversation=None, + stream=streaming, + extras={ + "auto_generate_conversation_name": False + } + ) + @classmethod def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig): if user_inputs is None: diff --git a/api/services/errors/__init__.py b/api/services/errors/__init__.py index a44c190cbc..5804f599fe 100644 --- a/api/services/errors/__init__.py +++ b/api/services/errors/__init__.py @@ -1,7 +1,7 @@ # -*- coding:utf-8 -*- __all__ = [ 'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset', - 'completion', 'audio', 'file' + 'app', 'completion', 'audio', 'file' ] from . import * diff --git a/api/services/errors/app.py b/api/services/errors/app.py new file mode 100644 index 0000000000..7c4ca99c2a --- /dev/null +++ b/api/services/errors/app.py @@ -0,0 +1,2 @@ +class MoreLikeThisDisabledError(Exception): + pass diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index c174983e38..a95a6dc52f 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -160,6 +160,7 @@ def test__get_chat_model_prompt_messages(): context = "yes or no." query = "How are you?" prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages( + app_mode=AppMode.CHAT, pre_prompt=pre_prompt, inputs=inputs, query=query, @@ -214,6 +215,7 @@ def test__get_completion_model_prompt_messages(): context = "yes or no." query = "How are you?" prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages( + app_mode=AppMode.CHAT, pre_prompt=pre_prompt, inputs=inputs, query=query, From 98cb17e79e7c5bf827292889ed8f496b7362453a Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 25 Feb 2024 21:30:44 +0800 Subject: [PATCH 058/200] lint fix --- api/core/app_runner/app_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/app_runner/app_runner.py b/api/core/app_runner/app_runner.py index 231530ef08..95f2f568dc 100644 --- a/api/core/app_runner/app_runner.py +++ b/api/core/app_runner/app_runner.py @@ -24,7 +24,7 @@ from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.simple_prompt_transform import SimplePromptTransform -from models.model import App, Message, MessageAnnotation, AppMode +from models.model import App, AppMode, Message, MessageAnnotation class AppRunner: From 34ed5e428cdf2f116156033e3ae3dfa33b53651a Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 25 Feb 2024 21:55:39 +0800 Subject: [PATCH 059/200] fix bugs --- api/core/prompt/advanced_prompt_transform.py | 34 +++++++++++++------ .../prompt/test_advanced_prompt_transform.py | 1 + 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 0ed9ec352c..7519971ce7 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -39,6 +39,7 @@ class AdvancedPromptTransform(PromptTransform): prompt_messages = self._get_completion_model_prompt_messages( prompt_template_entity=prompt_template_entity, inputs=inputs, + query=query, files=files, context=context, memory=memory, @@ -60,6 +61,7 @@ class AdvancedPromptTransform(PromptTransform): 
def _get_completion_model_prompt_messages(self, prompt_template_entity: PromptTemplateEntity, inputs: dict, + query: Optional[str], files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], @@ -86,6 +88,9 @@ class AdvancedPromptTransform(PromptTransform): model_config=model_config ) + if query: + prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs) + prompt = prompt_template.format( prompt_inputs ) @@ -147,21 +152,30 @@ class AdvancedPromptTransform(PromptTransform): else: prompt_messages.append(UserPromptMessage(content=query)) elif files: - # get last message - last_message = prompt_messages[-1] if prompt_messages else None - if last_message and last_message.role == PromptMessageRole.USER: - # get last user message content and add files - prompt_message_contents = [TextPromptMessageContent(data=last_message.content)] - for file in files: - prompt_message_contents.append(file.prompt_message_content) + if not query: + # get last message + last_message = prompt_messages[-1] if prompt_messages else None + if last_message and last_message.role == PromptMessageRole.USER: + # get last user message content and add files + prompt_message_contents = [TextPromptMessageContent(data=last_message.content)] + for file in files: + prompt_message_contents.append(file.prompt_message_content) - last_message.content = prompt_message_contents + last_message.content = prompt_message_contents + else: + prompt_message_contents = [TextPromptMessageContent(data='')] # not for query + for file in files: + prompt_message_contents.append(file.prompt_message_content) + + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: - prompt_message_contents = [TextPromptMessageContent(data='')] # not for query + prompt_message_contents = [TextPromptMessageContent(data=query)] for file in files: prompt_message_contents.append(file.prompt_message_content) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + elif query: + prompt_messages.append(UserPromptMessage(content=query)) return prompt_messages @@ -210,4 +224,4 @@ class AdvancedPromptTransform(PromptTransform): else: prompt_inputs['#histories#'] = '' - return prompt_inputs + return prompt_inputs diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 65a160a8e5..95f1e30b44 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -50,6 +50,7 @@ def test__get_completion_model_prompt_messages(): prompt_messages = prompt_transform._get_completion_model_prompt_messages( prompt_template_entity=prompt_template_entity, inputs=inputs, + query=None, files=files, context=context, memory=memory, From 77f04603b3633c809e03d6e4b7b4d79d18d6ce59 Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 25 Feb 2024 22:11:20 +0800 Subject: [PATCH 060/200] fix bugs --- api/core/prompt/simple_prompt_transform.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index a929416be4..fcae0dc786 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -178,11 +178,8 @@ class SimplePromptTransform(PromptTransform): context=context ) - if prompt: - if query: - prompt_messages.append(SystemPromptMessage(content=prompt)) - else: - 
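The query/files branching above reduces to: attach image parts to whichever user message carries the query. A toy sketch with content parts as plain dicts (the real classes are TextPromptMessageContent and the file objects' prompt_message_content):

class FakeFile:
    def __init__(self, url: str):
        self.prompt_message_content = {'type': 'image', 'data': url}

def build_user_message(query: str, files: list) -> dict:
    if not files:
        return {'role': 'user', 'content': query}
    contents = [{'type': 'text', 'data': query}]
    contents.extend(f.prompt_message_content for f in files)
    return {'role': 'user', 'content': contents}

msg = build_user_message('How are you?', [FakeFile('https://example.com/a.png')])
assert len(msg['content']) == 2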
prompt_messages.append(UserPromptMessage(content=prompt)) + if prompt and query: + prompt_messages.append(SystemPromptMessage(content=prompt)) if memory: prompt_messages = self._append_chat_histories( @@ -193,6 +190,8 @@ class SimplePromptTransform(PromptTransform): if query: prompt_messages.append(self.get_last_user_message(query, files)) + else: + prompt_messages.append(self.get_last_user_message(prompt, files)) return prompt_messages, None From a9192bc1c63352fbf3134100ca9db355fa02dbe0 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 26 Feb 2024 12:43:46 +0800 Subject: [PATCH 061/200] make recommended app list api public --- .../console/explore/recommended_app.py | 32 +++++++------------ 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index fd90be03b1..8b8fe349ed 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,5 +1,5 @@ from flask_login import current_user -from flask_restful import Resource, fields, marshal_with +from flask_restful import Resource, fields, marshal_with, reqparse from sqlalchemy import and_ from constants.languages import languages @@ -28,9 +28,6 @@ recommended_app_fields = { 'category': fields.String, 'position': fields.Integer, 'is_listed': fields.Boolean, - 'install_count': fields.Integer, - 'installed': fields.Boolean, - 'editable': fields.Boolean, 'is_agent': fields.Boolean } @@ -41,11 +38,19 @@ recommended_app_list_fields = { class RecommendedAppListApi(Resource): - @login_required - @account_initialization_required @marshal_with(recommended_app_list_fields) def get(self): - language_prefix = current_user.interface_language if current_user.interface_language else languages[0] + # language args + parser = reqparse.RequestParser() + parser.add_argument('language', type=str, location='args') + args = parser.parse_args() + + if args.get('language') and args.get('language') in languages: + language_prefix = args.get('language') + elif current_user and current_user.interface_language: + language_prefix = current_user.interface_language + else: + language_prefix = languages[0] recommended_apps = db.session.query(RecommendedApp).filter( RecommendedApp.is_listed == True, @@ -53,16 +58,8 @@ class RecommendedAppListApi(Resource): ).all() categories = set() - current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) recommended_apps_result = [] for recommended_app in recommended_apps: - installed = db.session.query(InstalledApp).filter( - and_( - InstalledApp.app_id == recommended_app.app_id, - InstalledApp.tenant_id == current_user.current_tenant_id - ) - ).first() is not None - app = recommended_app.app if not app or not app.is_public: continue @@ -81,9 +78,6 @@ class RecommendedAppListApi(Resource): 'category': recommended_app.category, 'position': recommended_app.position, 'is_listed': recommended_app.is_listed, - 'install_count': recommended_app.install_count, - 'installed': installed, - 'editable': current_user.role in ['owner', 'admin'], "is_agent": app.is_agent } recommended_apps_result.append(recommended_app_result) @@ -114,8 +108,6 @@ class RecommendedAppApi(Resource): 'app_model_config': fields.Nested(model_config_fields), } - @login_required - @account_initialization_required @marshal_with(app_simple_detail_fields) def get(self, app_id): app_id = str(app_id) From 78afba49bf336542bec774bc9b859d57c4556f7a Mon Sep 17 00:00:00 2001 From: 
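The language resolution added to RecommendedAppListApi above falls back in three steps; standalone, with the supported list abbreviated:

languages = ['en-US', 'zh-Hans', 'uk-UA']  # abbreviated from constants.languages

def resolve_language(arg_language, user_language):
    # an explicit, supported ?language= arg wins; then the signed-in
    # user's interface language; then the first supported language
    if arg_language and arg_language in languages:
        return arg_language
    if user_language:
        return user_language
    return languages[0]

assert resolve_language('uk-UA', 'zh-Hans') == 'uk-UA'
assert resolve_language(None, None) == 'en-US'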
takatost Date: Mon, 26 Feb 2024 12:44:21 +0800 Subject: [PATCH 062/200] lint fix --- api/controllers/console/explore/recommended_app.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 8b8fe349ed..6ba04d603a 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,15 +1,11 @@ from flask_login import current_user from flask_restful import Resource, fields, marshal_with, reqparse -from sqlalchemy import and_ from constants.languages import languages from controllers.console import api from controllers.console.app.error import AppNotFoundError -from controllers.console.wraps import account_initialization_required from extensions.ext_database import db -from libs.login import login_required -from models.model import App, InstalledApp, RecommendedApp -from services.account_service import TenantService +from models.model import App, RecommendedApp app_fields = { 'id': fields.String, From 27ba5a0bce66879969e6da9d2554b86815fdcb76 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 27 Feb 2024 13:23:01 +0800 Subject: [PATCH 063/200] refactor app mode add app import and export --- api/constants/languages.py | 509 ------------------ api/constants/model_template.py | 99 ++-- api/controllers/console/app/app.py | 239 +++++--- api/controllers/console/app/workflow.py | 11 +- api/controllers/console/app/wraps.py | 18 +- .../console/explore/installed_app.py | 3 +- .../console/explore/recommended_app.py | 64 ++- api/core/provider_manager.py | 2 +- api/fields/app_fields.py | 12 - api/fields/installed_app_fields.py | 3 +- .../versions/b289e2408ee2_add_workflow.py | 2 - ...998d4d_set_model_config_column_nullable.py | 70 +++ api/models/model.py | 53 +- api/services/workflow/workflow_converter.py | 4 +- api/services/workflow_service.py | 43 +- 15 files changed, 371 insertions(+), 761 deletions(-) create mode 100644 api/migrations/versions/cc04d0998d4d_set_model_config_column_nullable.py diff --git a/api/constants/languages.py b/api/constants/languages.py index 0ae69d77d2..0147dd8d70 100644 --- a/api/constants/languages.py +++ b/api/constants/languages.py @@ -91,512 +91,3 @@ user_input_form_template = { } ], } - -demo_model_templates = { - 'en-US': [ - { - 'name': 'Translation Assistant', - 'icon': '', - 'icon_background': '', - 'description': 'A multilingual translator that provides translation capabilities in multiple languages, translating user input into the language they need.', - 'mode': 'completion', - 'model_config': AppModelConfig( - provider='openai', - model_id='gpt-3.5-turbo-instruct', - configs={ - 'prompt_template': "Please translate the following text into {{target_language}}:\n", - 'prompt_variables': [ - { - "key": "target_language", - "name": "Target Language", - "description": "The language you want to translate into.", - "type": "select", - "default": "Chinese", - 'options': [ - 'Chinese', - 'English', - 'Japanese', - 'French', - 'Russian', - 'German', - 'Spanish', - 'Korean', - 'Italian', - ] - } - ], - 'completion_params': { - 'max_token': 1000, - 'temperature': 0, - 'top_p': 0, - 'presence_penalty': 0.1, - 'frequency_penalty': 0.1, - } - }, - opening_statement='', - suggested_questions=None, - pre_prompt="Please translate the following text into {{target_language}}:\n{{query}}\ntranslate:", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo-instruct", - "mode": "completion", 
- "completion_params": { - "max_tokens": 1000, - "temperature": 0, - "top_p": 0, - "presence_penalty": 0.1, - "frequency_penalty": 0.1 - } - }), - user_input_form=json.dumps([ - { - "select": { - "label": "Target Language", - "variable": "target_language", - "description": "The language you want to translate into.", - "default": "Chinese", - "required": True, - 'options': [ - 'Chinese', - 'English', - 'Japanese', - 'French', - 'Russian', - 'German', - 'Spanish', - 'Korean', - 'Italian', - ] - } - }, { - "paragraph": { - "label": "Query", - "variable": "query", - "required": True, - "default": "" - } - } - ]) - ) - }, - { - 'name': 'AI Front-end Interviewer', - 'icon': '', - 'icon_background': '', - 'description': 'A simulated front-end interviewer that tests the skill level of front-end development through questioning.', - 'mode': 'chat', - 'model_config': AppModelConfig( - provider='openai', - model_id='gpt-3.5-turbo', - configs={ - 'introduction': 'Hi, welcome to our interview. I am the interviewer for this technology company, and I will test your web front-end development skills. Next, I will ask you some technical questions. Please answer them as thoroughly as possible. ', - 'prompt_template': "You will play the role of an interviewer for a technology company, examining the user's web front-end development skills and posing 5-10 sharp technical questions.\n\nPlease note:\n- Only ask one question at a time.\n- After the user answers a question, ask the next question directly, without trying to correct any mistakes made by the candidate.\n- If you think the user has not answered correctly for several consecutive questions, ask fewer questions.\n- After asking the last question, you can ask this question: Why did you leave your last job? After the user answers this question, please express your understanding and support.\n", - 'prompt_variables': [], - 'completion_params': { - 'max_token': 300, - 'temperature': 0.8, - 'top_p': 0.9, - 'presence_penalty': 0.1, - 'frequency_penalty': 0.1, - } - }, - opening_statement='Hi, welcome to our interview. I am the interviewer for this technology company, and I will test your web front-end development skills. Next, I will ask you some technical questions. Please answer them as thoroughly as possible. ', - suggested_questions=None, - pre_prompt="You will play the role of an interviewer for a technology company, examining the user's web front-end development skills and posing 5-10 sharp technical questions.\n\nPlease note:\n- Only ask one question at a time.\n- After the user answers a question, ask the next question directly, without trying to correct any mistakes made by the candidate.\n- If you think the user has not answered correctly for several consecutive questions, ask fewer questions.\n- After asking the last question, you can ask this question: Why did you leave your last job? 
After the user answers this question, please express your understanding and support.\n", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo", - "mode": "chat", - "completion_params": { - "max_tokens": 300, - "temperature": 0.8, - "top_p": 0.9, - "presence_penalty": 0.1, - "frequency_penalty": 0.1 - } - }), - user_input_form=None - ) - } - ], - 'zh-Hans': [ - { - 'name': '翻译助手', - 'icon': '', - 'icon_background': '', - 'description': '一个多语言翻译器,提供多种语言翻译能力,将用户输入的文本翻译成他们需要的语言。', - 'mode': 'completion', - 'model_config': AppModelConfig( - provider='openai', - model_id='gpt-3.5-turbo-instruct', - configs={ - 'prompt_template': "请将以下文本翻译为{{target_language}}:\n", - 'prompt_variables': [ - { - "key": "target_language", - "name": "目标语言", - "description": "翻译的目标语言", - "type": "select", - "default": "中文", - "options": [ - "中文", - "英文", - "日语", - "法语", - "俄语", - "德语", - "西班牙语", - "韩语", - "意大利语", - ] - } - ], - 'completion_params': { - 'max_token': 1000, - 'temperature': 0, - 'top_p': 0, - 'presence_penalty': 0.1, - 'frequency_penalty': 0.1, - } - }, - opening_statement='', - suggested_questions=None, - pre_prompt="请将以下文本翻译为{{target_language}}:\n{{query}}\n翻译:", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo-instruct", - "mode": "completion", - "completion_params": { - "max_tokens": 1000, - "temperature": 0, - "top_p": 0, - "presence_penalty": 0.1, - "frequency_penalty": 0.1 - } - }), - user_input_form=json.dumps([ - { - "select": { - "label": "目标语言", - "variable": "target_language", - "description": "翻译的目标语言", - "default": "中文", - "required": True, - 'options': [ - "中文", - "英文", - "日语", - "法语", - "俄语", - "德语", - "西班牙语", - "韩语", - "意大利语", - ] - } - }, { - "paragraph": { - "label": "文本内容", - "variable": "query", - "required": True, - "default": "" - } - } - ]) - ) - }, - { - 'name': 'AI 前端面试官', - 'icon': '', - 'icon_background': '', - 'description': '一个模拟的前端面试官,通过提问的方式对前端开发的技能水平进行检验。', - 'mode': 'chat', - 'model_config': AppModelConfig( - provider='openai', - model_id='gpt-3.5-turbo', - configs={ - 'introduction': '你好,欢迎来参加我们的面试,我是这家科技公司的面试官,我将考察你的 Web 前端开发技能。接下来我会向您提出一些技术问题,请您尽可能详尽地回答。', - 'prompt_template': "你将扮演一个科技公司的面试官,考察用户作为候选人的 Web 前端开发水平,提出 5-10 个犀利的技术问题。\n\n请注意:\n- 每次只问一个问题\n- 用户回答问题后请直接问下一个问题,而不要试图纠正候选人的错误;\n- 如果你认为用户连续几次回答的都不对,就少问一点;\n- 问完最后一个问题后,你可以问这样一个问题:上一份工作为什么离职?用户回答该问题后,请表示理解与支持。\n", - 'prompt_variables': [], - 'completion_params': { - 'max_token': 300, - 'temperature': 0.8, - 'top_p': 0.9, - 'presence_penalty': 0.1, - 'frequency_penalty': 0.1, - } - }, - opening_statement='你好,欢迎来参加我们的面试,我是这家科技公司的面试官,我将考察你的 Web 前端开发技能。接下来我会向您提出一些技术问题,请您尽可能详尽地回答。', - suggested_questions=None, - pre_prompt="你将扮演一个科技公司的面试官,考察用户作为候选人的 Web 前端开发水平,提出 5-10 个犀利的技术问题。\n\n请注意:\n- 每次只问一个问题\n- 用户回答问题后请直接问下一个问题,而不要试图纠正候选人的错误;\n- 如果你认为用户连续几次回答的都不对,就少问一点;\n- 问完最后一个问题后,你可以问这样一个问题:上一份工作为什么离职?用户回答该问题后,请表示理解与支持。\n", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo", - "mode": "chat", - "completion_params": { - "max_tokens": 300, - "temperature": 0.8, - "top_p": 0.9, - "presence_penalty": 0.1, - "frequency_penalty": 0.1 - } - }), - user_input_form=None - ) - } - ], - 'uk-UA': [ - { - "name": "Помічник перекладу", - "icon": "", - "icon_background": "", - "description": "Багатомовний перекладач, який надає можливості перекладу різними мовами, перекладаючи введені користувачем дані на потрібну мову.", - "mode": "completion", - "model_config": AppModelConfig( - provider="openai", - model_id="gpt-3.5-turbo-instruct", - configs={ - "prompt_template": "Будь ласка, 
перекладіть наступний текст на {{target_language}}:\n", - "prompt_variables": [ - { - "key": "target_language", - "name": "Цільова мова", - "description": "Мова, на яку ви хочете перекласти.", - "type": "select", - "default": "Ukrainian", - "options": [ - "Chinese", - "English", - "Japanese", - "French", - "Russian", - "German", - "Spanish", - "Korean", - "Italian", - ], - }, - ], - "completion_params": { - "max_token": 1000, - "temperature": 0, - "top_p": 0, - "presence_penalty": 0.1, - "frequency_penalty": 0.1, - }, - }, - opening_statement="", - suggested_questions=None, - pre_prompt="Будь ласка, перекладіть наступний текст на {{target_language}}:\n{{query}}\ntranslate:", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo-instruct", - "mode": "completion", - "completion_params": { - "max_tokens": 1000, - "temperature": 0, - "top_p": 0, - "presence_penalty": 0.1, - "frequency_penalty": 0.1, - }, - }), - user_input_form=json.dumps([ - { - "select": { - "label": "Цільова мова", - "variable": "target_language", - "description": "Мова, на яку ви хочете перекласти.", - "default": "Chinese", - "required": True, - 'options': [ - 'Chinese', - 'English', - 'Japanese', - 'French', - 'Russian', - 'German', - 'Spanish', - 'Korean', - 'Italian', - ] - } - }, { - "paragraph": { - "label": "Запит", - "variable": "query", - "required": True, - "default": "" - } - } - ]) - ) - }, - { - "name": "AI інтерв’юер фронтенду", - "icon": "", - "icon_background": "", - "description": "Симульований інтерв’юер фронтенду, який перевіряє рівень кваліфікації у розробці фронтенду через опитування.", - "mode": "chat", - "model_config": AppModelConfig( - provider="openai", - model_id="gpt-3.5-turbo", - configs={ - "introduction": "Привіт, ласкаво просимо на наше співбесіду. Я інтерв'юер цієї технологічної компанії, і я перевірю ваші навички веб-розробки фронтенду. Далі я поставлю вам декілька технічних запитань. Будь ласка, відповідайте якомога ретельніше. ", - "prompt_template": "Ви будете грати роль інтерв'юера технологічної компанії, перевіряючи навички розробки фронтенду користувача та ставлячи 5-10 чітких технічних питань.\n\nЗверніть увагу:\n- Ставте лише одне запитання за раз.\n- Після того, як користувач відповість на запитання, ставте наступне запитання безпосередньо, не намагаючись виправити будь-які помилки, допущені кандидатом.\n- Якщо ви вважаєте, що користувач не відповів правильно на кілька питань поспіль, задайте менше запитань.\n- Після того, як ви задали останнє запитання, ви можете поставити таке запитання: Чому ви залишили свою попередню роботу? Після того, як користувач відповість на це питання, висловіть своє розуміння та підтримку.\n", - "prompt_variables": [], - "completion_params": { - "max_token": 300, - "temperature": 0.8, - "top_p": 0.9, - "presence_penalty": 0.1, - "frequency_penalty": 0.1, - }, - }, - opening_statement="Привіт, ласкаво просимо на наше співбесіду. Я інтерв'юер цієї технологічної компанії, і я перевірю ваші навички веб-розробки фронтенду. Далі я поставлю вам декілька технічних запитань. Будь ласка, відповідайте якомога ретельніше. 
", - suggested_questions=None, - pre_prompt="Ви будете грати роль інтерв'юера технологічної компанії, перевіряючи навички розробки фронтенду користувача та ставлячи 5-10 чітких технічних питань.\n\nЗверніть увагу:\n- Ставте лише одне запитання за раз.\n- Після того, як користувач відповість на запитання, ставте наступне запитання безпосередньо, не намагаючись виправити будь-які помилки, допущені кандидатом.\n- Якщо ви вважаєте, що користувач не відповів правильно на кілька питань поспіль, задайте менше запитань.\n- Після того, як ви задали останнє запитання, ви можете поставити таке запитання: Чому ви залишили свою попередню роботу? Після того, як користувач відповість на це питання, висловіть своє розуміння та підтримку.\n", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo", - "mode": "chat", - "completion_params": { - "max_tokens": 300, - "temperature": 0.8, - "top_p": 0.9, - "presence_penalty": 0.1, - "frequency_penalty": 0.1, - }, - }), - user_input_form=None - ), - } - ], - 'vi-VN': [ - { - 'name': 'Trợ lý dịch thuật', - 'icon': '', - 'icon_background': '', - 'description': 'Trình dịch đa ngôn ngữ cung cấp khả năng dịch bằng nhiều ngôn ngữ, dịch thông tin đầu vào của người dùng sang ngôn ngữ họ cần.', - 'mode': 'completion', - 'model_config': AppModelConfig( - provider='openai', - model_id='gpt-3.5-turbo-instruct', - configs={ - 'prompt_template': "Hãy dịch đoạn văn bản sau sang ngôn ngữ {{target_language}}:\n", - 'prompt_variables': [ - { - "key": "target_language", - "name": "Ngôn ngữ đích", - "description": "Ngôn ngữ bạn muốn dịch sang.", - "type": "select", - "default": "Vietnamese", - 'options': [ - 'Chinese', - 'English', - 'Japanese', - 'French', - 'Russian', - 'German', - 'Spanish', - 'Korean', - 'Italian', - 'Vietnamese', - ] - } - ], - 'completion_params': { - 'max_token': 1000, - 'temperature': 0, - 'top_p': 0, - 'presence_penalty': 0.1, - 'frequency_penalty': 0.1, - } - }, - opening_statement='', - suggested_questions=None, - pre_prompt="Hãy dịch đoạn văn bản sau sang {{target_language}}:\n{{query}}\ndịch:", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo-instruct", - "mode": "completion", - "completion_params": { - "max_tokens": 1000, - "temperature": 0, - "top_p": 0, - "presence_penalty": 0.1, - "frequency_penalty": 0.1 - } - }), - user_input_form=json.dumps([ - { - "select": { - "label": "Ngôn ngữ đích", - "variable": "target_language", - "description": "Ngôn ngữ bạn muốn dịch sang.", - "default": "Vietnamese", - "required": True, - 'options': [ - 'Chinese', - 'English', - 'Japanese', - 'French', - 'Russian', - 'German', - 'Spanish', - 'Korean', - 'Italian', - 'Vietnamese', - ] - } - }, { - "paragraph": { - "label": "Query", - "variable": "query", - "required": True, - "default": "" - } - } - ]) - ) - }, - { - 'name': 'Phỏng vấn front-end AI', - 'icon': '', - 'icon_background': '', - 'description': 'Một người phỏng vấn front-end mô phỏng để kiểm tra mức độ kỹ năng phát triển front-end thông qua việc đặt câu hỏi.', - 'mode': 'chat', - 'model_config': AppModelConfig( - provider='openai', - model_id='gpt-3.5-turbo', - configs={ - 'introduction': 'Xin chào, chào mừng đến với cuộc phỏng vấn của chúng tôi. Tôi là người phỏng vấn cho công ty công nghệ này và tôi sẽ kiểm tra kỹ năng phát triển web front-end của bạn. Tiếp theo, tôi sẽ hỏi bạn một số câu hỏi kỹ thuật. Hãy trả lời chúng càng kỹ lưỡng càng tốt. 
', - 'prompt_template': "Bạn sẽ đóng vai người phỏng vấn cho một công ty công nghệ, kiểm tra kỹ năng phát triển web front-end của người dùng và đặt ra 5-10 câu hỏi kỹ thuật sắc bén.\n\nXin lưu ý:\n- Mỗi lần chỉ hỏi một câu hỏi.\n - Sau khi người dùng trả lời một câu hỏi, hãy hỏi trực tiếp câu hỏi tiếp theo mà không cố gắng sửa bất kỳ lỗi nào mà thí sinh mắc phải.\n- Nếu bạn cho rằng người dùng đã không trả lời đúng cho một số câu hỏi liên tiếp, hãy hỏi ít câu hỏi hơn.\n- Sau đặt câu hỏi cuối cùng, bạn có thể hỏi câu hỏi này: Tại sao bạn lại rời bỏ công việc cuối cùng của mình? Sau khi người dùng trả lời câu hỏi này, vui lòng bày tỏ sự hiểu biết và ủng hộ của bạn.\n", - 'prompt_variables': [], - 'completion_params': { - 'max_token': 300, - 'temperature': 0.8, - 'top_p': 0.9, - 'presence_penalty': 0.1, - 'frequency_penalty': 0.1, - } - }, - opening_statement='Xin chào, chào mừng đến với cuộc phỏng vấn của chúng tôi. Tôi là người phỏng vấn cho công ty công nghệ này và tôi sẽ kiểm tra kỹ năng phát triển web front-end của bạn. Tiếp theo, tôi sẽ hỏi bạn một số câu hỏi kỹ thuật. Hãy trả lời chúng càng kỹ lưỡng càng tốt. ', - suggested_questions=None, - pre_prompt="Bạn sẽ đóng vai người phỏng vấn cho một công ty công nghệ, kiểm tra kỹ năng phát triển web front-end của người dùng và đặt ra 5-10 câu hỏi kỹ thuật sắc bén.\n\nXin lưu ý:\n- Mỗi lần chỉ hỏi một câu hỏi.\n - Sau khi người dùng trả lời một câu hỏi, hãy hỏi trực tiếp câu hỏi tiếp theo mà không cố gắng sửa bất kỳ lỗi nào mà thí sinh mắc phải.\n- Nếu bạn cho rằng người dùng đã không trả lời đúng cho một số câu hỏi liên tiếp, hãy hỏi ít câu hỏi hơn.\n- Sau đặt câu hỏi cuối cùng, bạn có thể hỏi câu hỏi này: Tại sao bạn lại rời bỏ công việc cuối cùng của mình? Sau khi người dùng trả lời câu hỏi này, vui lòng bày tỏ sự hiểu biết và ủng hộ của bạn.\n", - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo", - "mode": "chat", - "completion_params": { - "max_tokens": 300, - "temperature": 0.8, - "top_p": 0.9, - "presence_penalty": 0.1, - "frequency_penalty": 0.1 - } - }), - user_input_form=None - ) - } - ], -} diff --git a/api/constants/model_template.py b/api/constants/model_template.py index c22306ac87..ca0b754989 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -1,50 +1,25 @@ -import json +from models.model import AppMode -model_templates = { +default_app_templates = { # workflow default mode - 'workflow_default': { + AppMode.WORKFLOW: { 'app': { - 'mode': 'workflow', + 'mode': AppMode.WORKFLOW.value, 'enable_site': True, - 'enable_api': True, - 'is_demo': False, - 'api_rpm': 0, - 'api_rph': 0, - 'status': 'normal' + 'enable_api': True }, - 'model_config': { - 'provider': '', - 'model_id': '', - 'configs': {} - } + 'model_config': {} }, # chat default mode - 'chat_default': { + AppMode.CHAT: { 'app': { - 'mode': 'chat', + 'mode': AppMode.CHAT.value, 'enable_site': True, - 'enable_api': True, - 'is_demo': False, - 'api_rpm': 0, - 'api_rph': 0, - 'status': 'normal' + 'enable_api': True }, 'model_config': { - 'provider': 'openai', - 'model_id': 'gpt-4', - 'configs': { - 'prompt_template': '', - 'prompt_variables': [], - 'completion_params': { - 'max_token': 512, - 'temperature': 1, - 'top_p': 1, - 'presence_penalty': 0, - 'frequency_penalty': 0, - } - }, - 'model': json.dumps({ + 'model': { "provider": "openai", "name": "gpt-4", "mode": "chat", @@ -55,36 +30,19 @@ model_templates = { "presence_penalty": 0, "frequency_penalty": 0 } - }) + } } }, - # agent default mode - 'agent_default': { + # 
advanced-chat default mode + AppMode.ADVANCED_CHAT: { 'app': { - 'mode': 'agent', + 'mode': AppMode.ADVANCED_CHAT.value, 'enable_site': True, - 'enable_api': True, - 'is_demo': False, - 'api_rpm': 0, - 'api_rph': 0, - 'status': 'normal' + 'enable_api': True }, 'model_config': { - 'provider': 'openai', - 'model_id': 'gpt-4', - 'configs': { - 'prompt_template': '', - 'prompt_variables': [], - 'completion_params': { - 'max_token': 512, - 'temperature': 1, - 'top_p': 1, - 'presence_penalty': 0, - 'frequency_penalty': 0, - } - }, - 'model': json.dumps({ + 'model': { "provider": "openai", "name": "gpt-4", "mode": "chat", @@ -95,7 +53,30 @@ model_templates = { "presence_penalty": 0, "frequency_penalty": 0 } - }) + } + } + }, + + # agent-chat default mode + AppMode.AGENT_CHAT: { + 'app': { + 'mode': AppMode.AGENT_CHAT.value, + 'enable_site': True, + 'enable_api': True + }, + 'model_config': { + 'model': { + "provider": "openai", + "name": "gpt-4", + "mode": "chat", + "completion_params": { + "max_tokens": 512, + "temperature": 1, + "top_p": 1, + "presence_penalty": 0, + "frequency_penalty": 0 + } + } } }, } diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 93dc1ca34a..4c218bef1b 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,13 +1,15 @@ import json import logging from datetime import datetime +from typing import cast +import yaml from flask_login import current_user from flask_restful import Resource, abort, inputs, marshal_with, reqparse from werkzeug.exceptions import Forbidden -from constants.languages import demo_model_templates, languages -from constants.model_template import model_templates +from constants.languages import languages +from constants.model_template import default_app_templates from controllers.console import api from controllers.console.app.error import ProviderNotInitializeError from controllers.console.app.wraps import get_app_model @@ -15,7 +17,8 @@ from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.model_entities import ModelType, ModelPropertyKey +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.provider_manager import ProviderManager from events.app_event import app_was_created, app_was_deleted from extensions.ext_database import db @@ -28,10 +31,15 @@ from fields.app_fields import ( from libs.login import login_required from models.model import App, AppModelConfig, Site, AppMode from services.app_model_config_service import AppModelConfigService +from services.workflow_service import WorkflowService from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.tool_manager import ToolManager from core.entities.application_entities import AgentToolEntity + +ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow'] + + class AppListApi(Resource): @setup_required @@ -43,7 +51,7 @@ class AppListApi(Resource): parser = reqparse.RequestParser() parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args') parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args') - 
parser.add_argument('mode', type=str, choices=['chat', 'completion', 'all'], default='all', location='args', required=False) + parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent', 'channel', 'all'], default='all', location='args', required=False) parser.add_argument('name', type=str, location='args', required=False) args = parser.parse_args() @@ -52,15 +60,20 @@ class AppListApi(Resource): App.is_universal == False ] - if args['mode'] == 'completion': - filters.append(App.mode == 'completion') + if args['mode'] == 'workflow': + filters.append(App.mode.in_([AppMode.WORKFLOW.value, AppMode.COMPLETION.value])) elif args['mode'] == 'chat': - filters.append(App.mode == 'chat') + filters.append(App.mode.in_([AppMode.CHAT.value, AppMode.ADVANCED_CHAT.value])) + elif args['mode'] == 'agent': + filters.append(App.mode == AppMode.AGENT_CHAT.value) + elif args['mode'] == 'channel': + filters.append(App.mode == AppMode.CHANNEL.value) else: pass if 'name' in args and args['name']: - filters.append(App.name.ilike(f'%{args["name"]}%')) + name = args['name'][:30] + filters.append(App.name.ilike(f'%{name}%')) app_models = db.paginate( db.select(App).where(*filters).order_by(App.created_at.desc()), @@ -80,10 +93,9 @@ class AppListApi(Resource): """Create app""" parser = reqparse.RequestParser() parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('mode', type=str, choices=['chat', 'agent', 'workflow'], location='json') + parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json') parser.add_argument('icon', type=str, location='json') parser.add_argument('icon_background', type=str, location='json') - parser.add_argument('model_config', type=dict, location='json') args = parser.parse_args() # The role of the current user in the ta table must be admin or owner @@ -141,15 +153,15 @@ class AppListApi(Resource): app_mode = AppMode.value_of(args['mode']) - model_config_template = model_templates[app_mode.value + '_default'] + app_template = default_app_templates[app_mode] - app = App(**model_config_template['app']) - app_model_config = AppModelConfig(**model_config_template['model_config']) - - if app_mode in [AppMode.CHAT, AppMode.AGENT]: + # get model config + default_model_config = app_template['model_config'] + if 'model' in default_model_config: # get model provider model_manager = ModelManager() + # get default model instance try: model_instance = model_manager.get_default_model_instance( tenant_id=current_user.current_tenant_id, @@ -159,10 +171,25 @@ class AppListApi(Resource): model_instance = None if model_instance: - model_dict = app_model_config.model_dict - model_dict['provider'] = model_instance.provider - model_dict['name'] = model_instance.model - app_model_config.model = json.dumps(model_dict) + if model_instance.model == default_model_config['model']['name']: + default_model_dict = default_model_config['model'] + else: + llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) + model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + + default_model_dict = { + 'provider': model_instance.provider, + 'name': model_instance.model, + 'mode': model_schema.model_properties.get(ModelPropertyKey.MODE), + 'completion_params': {} + } + else: + default_model_dict = default_model_config['model'] + + default_model_config['model'] = json.dumps(default_model_dict) + + app = App(**app_template['app']) + app_model_config = AppModelConfig(**default_model_config) app.name = 
args['name'] app.mode = args['mode'] @@ -195,24 +222,95 @@ class AppListApi(Resource): app_was_created.send(app) return app, 201 - -class AppTemplateApi(Resource): +class AppImportApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(template_list_fields) - def get(self): - """Get app demo templates""" + @marshal_with(app_detail_fields) + @cloud_edition_billing_resource_check('apps') + def post(self): + """Import app""" + # The role of the current user in the ta table must be admin or owner + if not current_user.is_admin_or_owner: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument('data', type=str, required=True, nullable=False, location='json') + parser.add_argument('name', type=str, location='json') + parser.add_argument('icon', type=str, location='json') + parser.add_argument('icon_background', type=str, location='json') + args = parser.parse_args() + + try: + import_data = yaml.safe_load(args['data']) + except yaml.YAMLError as e: + raise ValueError("Invalid YAML format in data argument.") + + app_data = import_data.get('app') + model_config_data = import_data.get('model_config') + workflow_graph = import_data.get('workflow_graph') + + if not app_data or not model_config_data: + raise ValueError("Missing app or model_config in data argument") + + app_mode = AppMode.value_of(app_data.get('mode')) + if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if not workflow_graph: + raise ValueError("Missing workflow_graph in data argument " + "when mode is advanced-chat or workflow") + + app = App( + enable_site=True, + enable_api=True, + is_demo=False, + api_rpm=0, + api_rph=0, + status='normal' + ) + + app.tenant_id = current_user.current_tenant_id + app.mode = app_data.get('mode') + app.name = args.get("name") if args.get("name") else app_data.get('name') + app.icon = args.get("icon") if args.get("icon") else app_data.get('icon') + app.icon_background = args.get("icon_background") if args.get("icon_background") \ + else app_data.get('icon_background') + + db.session.add(app) + db.session.commit() + + if workflow_graph: + workflow_service = WorkflowService() + draft_workflow = workflow_service.sync_draft_workflow(app, workflow_graph, current_user) + published_workflow = workflow_service.publish_draft_workflow(app, current_user, draft_workflow) + model_config_data['workflow_id'] = published_workflow.id + + app_model_config = AppModelConfig() + app_model_config = app_model_config.from_model_config_dict(model_config_data) + app_model_config.app_id = app.id + + db.session.add(app_model_config) + db.session.commit() + + app.app_model_config_id = app_model_config.id + account = current_user - interface_language = account.interface_language - templates = demo_model_templates.get(interface_language) - if not templates: - templates = demo_model_templates.get(languages[0]) + site = Site( + app_id=app.id, + title=app.name, + default_language=account.interface_language, + customize_token_strategy='not_allow', + code=Site.generate_code(16) + ) - return {'data': templates} + db.session.add(site) + db.session.commit() + + app_was_created.send(app) + + return app, 201 class AppApi(Resource): @@ -283,6 +381,38 @@ class AppApi(Resource): return {'result': 'success'}, 204 +class AppExportApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model + def get(self, app_model): + """Export app""" + app_model_config = app_model.app_model_config + + export_data = { + "app": { + "name": 
app_model.name,
+                "mode": app_model.mode,
+                "icon": app_model.icon,
+                "icon_background": app_model.icon_background
+            },
+            "model_config": app_model_config.to_dict(),
+        }
+
+        if app_model_config.workflow_id:
+            export_data['workflow_graph'] = json.loads(app_model_config.workflow.graph)
+        else:
+            # get draft workflow
+            workflow_service = WorkflowService()
+            workflow = workflow_service.get_draft_workflow(app_model)
+            export_data['workflow_graph'] = json.loads(workflow.graph)
+
+        return {
+            "data": yaml.dump(export_data)
+        }
+
+
 class AppNameApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
     @get_app_model
@@ -360,57 +490,10 @@ class AppApiStatus(Resource):
         return app_model


-class AppCopy(Resource):
-    @staticmethod
-    def create_app_copy(app):
-        copy_app = App(
-            name=app.name + ' copy',
-            icon=app.icon,
-            icon_background=app.icon_background,
-            tenant_id=app.tenant_id,
-            mode=app.mode,
-            app_model_config_id=app.app_model_config_id,
-            enable_site=app.enable_site,
-            enable_api=app.enable_api,
-            api_rpm=app.api_rpm,
-            api_rph=app.api_rph
-        )
-        return copy_app
-
-    @staticmethod
-    def create_app_model_config_copy(app_config, copy_app_id):
-        copy_app_model_config = app_config.copy()
-        copy_app_model_config.app_id = copy_app_id
-
-        return copy_app_model_config
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_app_model
-    @marshal_with(app_detail_fields)
-    def post(self, app_model):
-        copy_app = self.create_app_copy(app_model)
-        db.session.add(copy_app)
-
-        app_config = db.session.query(AppModelConfig). \
-            filter(AppModelConfig.app_id == app_model.id). \
-            one_or_none()
-
-        if app_config:
-            copy_app_model_config = self.create_app_model_config_copy(app_config, copy_app.id)
-            db.session.add(copy_app_model_config)
-            db.session.commit()
-            copy_app.app_model_config_id = copy_app_model_config.id
-            db.session.commit()
-
-        return copy_app, 201
-
-
 api.add_resource(AppListApi, '/apps')
-api.add_resource(AppTemplateApi, '/app-templates')
+api.add_resource(AppImportApi, '/apps/import')
 api.add_resource(AppApi, '/apps/<uuid:app_id>')
-api.add_resource(AppCopy, '/apps/<uuid:app_id>/copy')
+api.add_resource(AppExportApi, '/apps/<uuid:app_id>/export')
 api.add_resource(AppNameApi, '/apps/<uuid:app_id>/name')
 api.add_resource(AppIconApi, '/apps/<uuid:app_id>/icon')
 api.add_resource(AppSiteStatus, '/apps/<uuid:app_id>/site-enable')
diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index dc1b7edcaf..6023d0ba45 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -7,7 +7,7 @@ from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from fields.workflow_fields import workflow_fields
 from libs.login import current_user, login_required
-from models.model import App, AppMode, ChatbotAppEngine
+from models.model import App, AppMode
 from services.workflow_service import WorkflowService


@@ -15,7 +15,7 @@ class DraftWorkflowApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @get_app_model(mode=[AppMode.CHAT, AppMode.WORKFLOW], app_engine=ChatbotAppEngine.WORKFLOW)
+    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
     @marshal_with(workflow_fields)
     def get(self, app_model: App):
         """
@@ -34,7 +34,7 @@
     @setup_required
     @login_required
     @account_initialization_required
-    @get_app_model(mode=[AppMode.CHAT, AppMode.WORKFLOW], app_engine=ChatbotAppEngine.WORKFLOW)
+    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
     def post(self, app_model: App):
         """
         Sync 
draft workflow @@ -55,7 +55,7 @@ class DefaultBlockConfigApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=[AppMode.CHAT, AppMode.WORKFLOW], app_engine=ChatbotAppEngine.WORKFLOW) + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) def get(self, app_model: App): """ Get default block config @@ -72,7 +72,8 @@ class ConvertToWorkflowApi(Resource): @get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION]) def post(self, app_model: App): """ - Convert basic mode of chatbot app(expert mode) to workflow mode + Convert basic mode of chatbot app to workflow mode + Convert expert mode of chatbot app to workflow mode Convert Completion App to Workflow App """ # convert to workflow mode diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index 1c2c4cf5c7..d61ab6d6ae 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -5,12 +5,11 @@ from typing import Optional, Union from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from libs.login import current_user -from models.model import App, AppMode, ChatbotAppEngine +from models.model import App, AppMode def get_app_model(view: Optional[Callable] = None, *, - mode: Union[AppMode, list[AppMode]] = None, - app_engine: ChatbotAppEngine = None): + mode: Union[AppMode, list[AppMode]] = None): def decorator(view_func): @wraps(view_func) def decorated_view(*args, **kwargs): @@ -32,6 +31,9 @@ def get_app_model(view: Optional[Callable] = None, *, raise AppNotFoundError() app_mode = AppMode.value_of(app_model.mode) + if app_mode == AppMode.CHANNEL: + raise AppNotFoundError() + if mode is not None: if isinstance(mode, list): modes = mode @@ -42,16 +44,6 @@ def get_app_model(view: Optional[Callable] = None, *, mode_values = {m.value for m in modes} raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}") - if app_engine is not None: - if app_mode not in [AppMode.CHAT, AppMode.WORKFLOW]: - raise AppNotFoundError(f"App mode is not supported for {app_engine.value} app engine.") - - if app_mode == AppMode.CHAT: - # fetch current app model config - app_model_config = app_model.app_model_config - if not app_model_config or app_model_config.chatbot_app_engine != app_engine.value: - raise AppNotFoundError(f"{app_engine.value} app engine is not supported.") - kwargs['app_model'] = app_model return view_func(*args, **kwargs) diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 920d9141ae..7d6231270f 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -34,8 +34,7 @@ class InstalledAppsListApi(Resource): 'is_pinned': installed_app.is_pinned, 'last_used_at': installed_app.last_used_at, 'editable': current_user.role in ["owner", "admin"], - 'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id, - 'is_agent': installed_app.is_agent + 'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id } for installed_app in installed_apps ] diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 6ba04d603a..3c28980f51 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,3 +1,6 @@ +import json + +import yaml from flask_login import current_user from flask_restful import Resource, fields, 
marshal_with, reqparse @@ -6,6 +9,7 @@ from controllers.console import api from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from models.model import App, RecommendedApp +from services.workflow_service import WorkflowService app_fields = { 'id': fields.String, @@ -23,8 +27,7 @@ recommended_app_fields = { 'privacy_policy': fields.String, 'category': fields.String, 'position': fields.Integer, - 'is_listed': fields.Boolean, - 'is_agent': fields.Boolean + 'is_listed': fields.Boolean } recommended_app_list_fields = { @@ -73,8 +76,7 @@ class RecommendedAppListApi(Resource): 'privacy_policy': site.privacy_policy, 'category': recommended_app.category, 'position': recommended_app.position, - 'is_listed': recommended_app.is_listed, - "is_agent": app.is_agent + 'is_listed': recommended_app.is_listed } recommended_apps_result.append(recommended_app_result) @@ -84,27 +86,6 @@ class RecommendedAppListApi(Resource): class RecommendedAppApi(Resource): - model_config_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw(attribute='suggested_questions_list'), - 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), - 'more_like_this': fields.Raw(attribute='more_like_this_dict'), - 'model': fields.Raw(attribute='model_dict'), - 'user_input_form': fields.Raw(attribute='user_input_form_list'), - 'pre_prompt': fields.String, - 'agent_mode': fields.Raw(attribute='agent_mode_dict'), - } - - app_simple_detail_fields = { - 'id': fields.String, - 'name': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'mode': fields.String, - 'app_model_config': fields.Nested(model_config_fields), - } - - @marshal_with(app_simple_detail_fields) def get(self, app_id): app_id = str(app_id) @@ -118,11 +99,38 @@ class RecommendedAppApi(Resource): raise AppNotFoundError # get app detail - app = db.session.query(App).filter(App.id == app_id).first() - if not app or not app.is_public: + app_model = db.session.query(App).filter(App.id == app_id).first() + if not app_model or not app_model.is_public: raise AppNotFoundError - return app + app_model_config = app_model.app_model_config + + export_data = { + "app": { + "name": app_model.name, + "mode": app_model.mode, + "icon": app_model.icon, + "icon_background": app_model.icon_background + }, + "model_config": app_model_config.to_dict(), + } + + if app_model_config.workflow_id: + export_data['workflow_graph'] = json.loads(app_model_config.workflow.graph) + else: + # get draft workflow + workflow_service = WorkflowService() + workflow = workflow_service.get_draft_workflow(app_model) + export_data['workflow_graph'] = json.loads(workflow.graph) + + return { + 'id': app_model.id, + 'name': app_model.name, + 'icon': app_model.icon, + 'icon_background': app_model.icon_background, + 'mode': app_model.mode, + 'export_data': yaml.dump(export_data) + } api.add_resource(RecommendedAppListApi, '/explore/apps') diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 6e28247d38..0db84d3b69 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -235,7 +235,7 @@ class ProviderManager: if available_models: found = False for available_model in available_models: - if available_model.model == "gpt-3.5-turbo-1106": + if available_model.model == "gpt-4": default_model = TenantDefaultModel( tenant_id=tenant_id, model_type=model_type.to_origin_model_type(), diff --git a/api/fields/app_fields.py 
b/api/fields/app_fields.py index e6c1272086..75b68d24fc 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -42,14 +42,10 @@ app_detail_fields = { 'id': fields.String, 'name': fields.String, 'mode': fields.String, - 'is_agent': fields.Boolean, 'icon': fields.String, 'icon_background': fields.String, 'enable_site': fields.Boolean, 'enable_api': fields.Boolean, - 'api_rpm': fields.Integer, - 'api_rph': fields.Integer, - 'is_demo': fields.Boolean, 'model_config': fields.Nested(model_config_fields, attribute='app_model_config'), 'created_at': TimestampField } @@ -67,12 +63,8 @@ app_partial_fields = { 'id': fields.String, 'name': fields.String, 'mode': fields.String, - 'is_agent': fields.Boolean, 'icon': fields.String, 'icon_background': fields.String, - 'enable_site': fields.Boolean, - 'enable_api': fields.Boolean, - 'is_demo': fields.Boolean, 'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'), 'created_at': TimestampField } @@ -122,10 +114,6 @@ app_detail_fields_with_site = { 'icon_background': fields.String, 'enable_site': fields.Boolean, 'enable_api': fields.Boolean, - 'api_rpm': fields.Integer, - 'api_rph': fields.Integer, - 'is_agent': fields.Boolean, - 'is_demo': fields.Boolean, 'model_config': fields.Nested(model_config_fields, attribute='app_model_config'), 'site': fields.Nested(site_fields), 'api_base_url': fields.String, diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py index 821d3c0ade..35cc5a6475 100644 --- a/api/fields/installed_app_fields.py +++ b/api/fields/installed_app_fields.py @@ -17,8 +17,7 @@ installed_app_fields = { 'is_pinned': fields.Boolean, 'last_used_at': TimestampField, 'editable': fields.Boolean, - 'uninstallable': fields.Boolean, - 'is_agent': fields.Boolean, + 'uninstallable': fields.Boolean } installed_app_list_fields = { diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py index 9e04fef288..7255b4b5fa 100644 --- a/api/migrations/versions/b289e2408ee2_add_workflow.py +++ b/api/migrations/versions/b289e2408ee2_add_workflow.py @@ -107,7 +107,6 @@ def upgrade(): batch_op.create_index('workflow_version_idx', ['tenant_id', 'app_id', 'version'], unique=False) with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('chatbot_app_engine', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False)) batch_op.add_column(sa.Column('workflow_id', postgresql.UUID(), nullable=True)) with op.batch_alter_table('messages', schema=None) as batch_op: @@ -123,7 +122,6 @@ def downgrade(): with op.batch_alter_table('app_model_configs', schema=None) as batch_op: batch_op.drop_column('workflow_id') - batch_op.drop_column('chatbot_app_engine') with op.batch_alter_table('workflows', schema=None) as batch_op: batch_op.drop_index('workflow_version_idx') diff --git a/api/migrations/versions/cc04d0998d4d_set_model_config_column_nullable.py b/api/migrations/versions/cc04d0998d4d_set_model_config_column_nullable.py new file mode 100644 index 0000000000..c302e8b530 --- /dev/null +++ b/api/migrations/versions/cc04d0998d4d_set_model_config_column_nullable.py @@ -0,0 +1,70 @@ +"""set model config column nullable + +Revision ID: cc04d0998d4d +Revises: b289e2408ee2 +Create Date: 2024-02-27 03:47:47.376325 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = 'cc04d0998d4d' +down_revision = 'b289e2408ee2' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.alter_column('provider', + existing_type=sa.VARCHAR(length=255), + nullable=True) + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=True) + batch_op.alter_column('configs', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=True) + + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.alter_column('api_rpm', + existing_type=sa.Integer(), + server_default='0', + nullable=False) + + batch_op.alter_column('api_rph', + existing_type=sa.Integer(), + server_default='0', + nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.alter_column('api_rpm', + existing_type=sa.Integer(), + server_default=None, + nullable=False) + + batch_op.alter_column('api_rph', + existing_type=sa.Integer(), + server_default=None, + nullable=False) + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.alter_column('configs', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=False) + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('provider', + existing_type=sa.VARCHAR(length=255), + nullable=False) + + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index 1e66fd6c88..713d8da577 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -31,7 +31,9 @@ class AppMode(Enum): COMPLETION = 'completion' WORKFLOW = 'workflow' CHAT = 'chat' - AGENT = 'agent' + ADVANCED_CHAT = 'advanced-chat' + AGENT_CHAT = 'agent-chat' + CHANNEL = 'channel' @classmethod def value_of(cls, value: str) -> 'AppMode': @@ -64,8 +66,8 @@ class App(db.Model): status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) enable_site = db.Column(db.Boolean, nullable=False) enable_api = db.Column(db.Boolean, nullable=False) - api_rpm = db.Column(db.Integer, nullable=False) - api_rph = db.Column(db.Integer, nullable=False) + api_rpm = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + api_rph = db.Column(db.Integer, nullable=False, server_default=db.text('0')) is_demo = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) is_public = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) @@ -92,19 +94,7 @@ class App(db.Model): def tenant(self): tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() return tenant - - @property - def is_agent(self) -> bool: - app_model_config = self.app_model_config - if not app_model_config: - return False - if not app_model_config.agent_mode: - return False - if self.app_model_config.agent_mode_dict.get('enabled', False) \ - and self.app_model_config.agent_mode_dict.get('strategy', '') in ['function_call', 'react']: - return True - return False - + @property def deleted_tools(self) -> list: # get agent mode tools @@ -153,11 +143,6 @@ class App(db.Model): return deleted_tools -class ChatbotAppEngine(Enum): - NORMAL = 'normal' - WORKFLOW = 'workflow' - - class 
AppModelConfig(db.Model): __tablename__ = 'app_model_configs' __table_args__ = ( @@ -167,9 +152,9 @@ class AppModelConfig(db.Model): id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) app_id = db.Column(UUID, nullable=False) - provider = db.Column(db.String(255), nullable=False) - model_id = db.Column(db.String(255), nullable=False) - configs = db.Column(db.JSON, nullable=False) + provider = db.Column(db.String(255), nullable=True) + model_id = db.Column(db.String(255), nullable=True) + configs = db.Column(db.JSON, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) opening_statement = db.Column(db.Text) @@ -191,7 +176,6 @@ class AppModelConfig(db.Model): dataset_configs = db.Column(db.Text) external_data_tools = db.Column(db.Text) file_upload = db.Column(db.Text) - chatbot_app_engine = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) workflow_id = db.Column(UUID) @property @@ -301,9 +285,6 @@ class AppModelConfig(db.Model): def to_dict(self) -> dict: return { - "provider": "", - "model_id": "", - "configs": {}, "opening_statement": self.opening_statement, "suggested_questions": self.suggested_questions_list, "suggested_questions_after_answer": self.suggested_questions_after_answer_dict, @@ -327,9 +308,6 @@ class AppModelConfig(db.Model): } def from_model_config_dict(self, model_config: dict): - self.provider = "" - self.model_id = "" - self.configs = {} self.opening_statement = model_config['opening_statement'] self.suggested_questions = json.dumps(model_config['suggested_questions']) self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) @@ -358,15 +336,13 @@ class AppModelConfig(db.Model): if model_config.get('dataset_configs') else None self.file_upload = json.dumps(model_config.get('file_upload')) \ if model_config.get('file_upload') else None + self.workflow_id = model_config.get('workflow_id') return self def copy(self): new_app_model_config = AppModelConfig( id=self.id, app_id=self.app_id, - provider="", - model_id="", - configs={}, opening_statement=self.opening_statement, suggested_questions=self.suggested_questions, suggested_questions_after_answer=self.suggested_questions_after_answer, @@ -385,7 +361,8 @@ class AppModelConfig(db.Model): chat_prompt_config=self.chat_prompt_config, completion_prompt_config=self.completion_prompt_config, dataset_configs=self.dataset_configs, - file_upload=self.file_upload + file_upload=self.file_upload, + workflow_id=self.workflow_id ) return new_app_model_config @@ -446,12 +423,6 @@ class InstalledApp(db.Model): tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() return tenant - @property - def is_agent(self) -> bool: - app = self.app - if not app: - return False - return app.is_agent class Conversation(db.Model): __tablename__ = 'conversations' diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index c6f0bed008..ed24762dd8 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -21,7 +21,7 @@ from events.app_event import app_was_created from extensions.ext_database import db from models.account import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint -from models.model import App, AppMode, AppModelConfig, 
ChatbotAppEngine, Site +from models.model import App, AppMode, AppModelConfig, Site from models.workflow import Workflow, WorkflowType @@ -85,8 +85,6 @@ class WorkflowConverter: new_app_model_config.chat_prompt_config = '' new_app_model_config.completion_prompt_config = '' new_app_model_config.dataset_configs = '' - new_app_model_config.chatbot_app_engine = ChatbotAppEngine.WORKFLOW.value \ - if app_model.mode == AppMode.CHAT.value else ChatbotAppEngine.NORMAL.value new_app_model_config.workflow_id = workflow.id db.session.add(new_app_model_config) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 4f7262b7d6..3143818d12 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1,9 +1,10 @@ import json from datetime import datetime +from typing import Optional from extensions.ext_database import db from models.account import Account -from models.model import App, AppMode, ChatbotAppEngine +from models.model import App, AppMode from models.workflow import Workflow, WorkflowType from services.workflow.defaults import default_block_configs from services.workflow.workflow_converter import WorkflowConverter @@ -58,6 +59,40 @@ class WorkflowService: # return draft workflow return workflow + def publish_draft_workflow(self, app_model: App, + account: Account, + draft_workflow: Optional[Workflow] = None) -> Workflow: + """ + Publish draft workflow + + :param app_model: App instance + :param account: Account instance + :param draft_workflow: Workflow instance + """ + if not draft_workflow: + # fetch draft workflow by app_model + draft_workflow = self.get_draft_workflow(app_model=app_model) + + if not draft_workflow: + raise ValueError('No valid workflow found.') + + # create new workflow + workflow = Workflow( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type=draft_workflow.type, + version=str(datetime.utcnow()), + graph=draft_workflow.graph, + created_by=account.id + ) + + # commit db session changes + db.session.add(workflow) + db.session.commit() + + # return new workflow + return workflow + def get_default_block_configs(self) -> dict: """ Get default block configs @@ -77,11 +112,7 @@ class WorkflowService: # chatbot convert to workflow mode workflow_converter = WorkflowConverter() - if app_model.mode == AppMode.CHAT.value: - # check if chatbot app is in basic mode - if app_model.app_model_config.chatbot_app_engine != ChatbotAppEngine.NORMAL: - raise ValueError('Chatbot app already in workflow mode') - elif app_model.mode != AppMode.COMPLETION.value: + if app_model.mode not in [AppMode.CHAT.value, AppMode.COMPLETION.value]: raise ValueError(f'Current App mode: {app_model.mode} is not supported convert to workflow.') # convert to workflow From 9f42892b42cdb88e7a1c71383f6d482891ec98b1 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 27 Feb 2024 13:23:20 +0800 Subject: [PATCH 064/200] lint fix --- api/constants/languages.py | 2 -- .../versions/cc04d0998d4d_set_model_config_column_nullable.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/api/constants/languages.py b/api/constants/languages.py index 0147dd8d70..dd8a29eaef 100644 --- a/api/constants/languages.py +++ b/api/constants/languages.py @@ -1,6 +1,4 @@ -import json -from models.model import AppModelConfig languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA', 'vi-VN'] diff --git a/api/migrations/versions/cc04d0998d4d_set_model_config_column_nullable.py 
b/api/migrations/versions/cc04d0998d4d_set_model_config_column_nullable.py
index c302e8b530..aefbe43f14 100644
--- a/api/migrations/versions/cc04d0998d4d_set_model_config_column_nullable.py
+++ b/api/migrations/versions/cc04d0998d4d_set_model_config_column_nullable.py
@@ -5,8 +5,8 @@ Revises: b289e2408ee2
 Create Date: 2024-02-27 03:47:47.376325

 """
-from alembic import op
 import sqlalchemy as sa
+from alembic import op
 from sqlalchemy.dialects import postgresql

 # revision identifiers, used by Alembic.

From c13e8077ba6bd364cb7058b02ed4cac3fa692e95 Mon Sep 17 00:00:00 2001
From: takatost
Date: Tue, 27 Feb 2024 13:27:46 +0800
Subject: [PATCH 065/200] fix agent app converter command

---
 api/commands.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/api/commands.py b/api/commands.py
index e376d222c6..73325620ee 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -405,12 +405,12 @@ def convert_to_agent_apps():
         click.echo('Converting app: {}'.format(app.id))

         try:
-            app.mode = AppMode.AGENT.value
+            app.mode = AppMode.AGENT_CHAT.value
             db.session.commit()

             # update conversation mode to agent
             db.session.query(Conversation).filter(Conversation.app_id == app.id).update(
-                {Conversation.mode: AppMode.AGENT.value}
+                {Conversation.mode: AppMode.AGENT_CHAT.value}
             )
             db.session.commit()

From 84c3ec0ea71bdcef0bde6901ddf4b6e3a64f2f56 Mon Sep 17 00:00:00 2001
From: takatost
Date: Tue, 27 Feb 2024 13:40:18 +0800
Subject: [PATCH 066/200] move site init to event handler

---
 api/controllers/console/app/app.py            | 172 +++++-------------
 api/events/event_handlers/__init__.py         |   1 +
 .../create_site_record_when_app_created.py    |  20 ++
 api/services/workflow/workflow_converter.py   |  13 +-
 4 files changed, 66 insertions(+), 140 deletions(-)
 create mode 100644 api/events/event_handlers/create_site_record_when_app_created.py

diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index 4c218bef1b..4d88733d5f 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -1,5 +1,4 @@
 import json
-import logging
 from datetime import datetime
 from typing import cast
@@ -8,29 +7,24 @@
 from flask_login import current_user
 from flask_restful import Resource, abort, inputs, marshal_with, reqparse
 from werkzeug.exceptions import Forbidden

-from constants.languages import languages
 from constants.model_template import default_app_templates
 from controllers.console import api
-from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.app.wraps import get_app_model
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
-from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
+from core.errors.error import ProviderTokenNotInitError
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType, ModelPropertyKey
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.provider_manager import ProviderManager
 from events.app_event import app_was_created, app_was_deleted
 from extensions.ext_database import db
 from fields.app_fields import (
     app_detail_fields,
     app_detail_fields_with_site,
     app_pagination_fields,
-    template_list_fields,
 )
 from libs.login import login_required
-from models.model import App, AppModelConfig, Site, AppMode
-from services.app_model_config_service import AppModelConfigService
+from 
models.model import App, AppModelConfig, AppMode from services.workflow_service import WorkflowService from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.tool_manager import ToolManager @@ -102,95 +96,47 @@ class AppListApi(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - # TODO: MOVE TO IMPORT API - if args['model_config'] is not None: - # validate config - model_config_dict = args['model_config'] + if 'mode' not in args or args['mode'] is None: + abort(400, message="mode is required") - # Get provider configurations - provider_manager = ProviderManager() - provider_configurations = provider_manager.get_configurations(current_user.current_tenant_id) + app_mode = AppMode.value_of(args['mode']) - # get available models from provider_configurations - available_models = provider_configurations.get_models( - model_type=ModelType.LLM, - only_active=True - ) + app_template = default_app_templates[app_mode] - # check if model is available - available_models_names = [f'{model.provider.provider}.{model.model}' for model in available_models] - provider_model = f"{model_config_dict['model']['provider']}.{model_config_dict['model']['name']}" - if provider_model not in available_models_names: - if not default_model_entity: - raise ProviderNotInitializeError( - "No Default System Reasoning Model available. Please configure " - "in the Settings -> Model Provider.") - else: - model_config_dict["model"]["provider"] = default_model_entity.provider.provider - model_config_dict["model"]["name"] = default_model_entity.model + # get model config + default_model_config = app_template['model_config'] + if 'model' in default_model_config: + # get model provider + model_manager = ModelManager() - model_configuration = AppModelConfigService.validate_configuration( - tenant_id=current_user.current_tenant_id, - account=current_user, - config=model_config_dict, - app_mode=args['mode'] - ) + # get default model instance + try: + model_instance = model_manager.get_default_model_instance( + tenant_id=current_user.current_tenant_id, + model_type=ModelType.LLM + ) + except ProviderTokenNotInitError: + model_instance = None - app = App( - enable_site=True, - enable_api=True, - is_demo=False, - api_rpm=0, - api_rph=0, - status='normal' - ) - - app_model_config = AppModelConfig() - app_model_config = app_model_config.from_model_config_dict(model_configuration) - else: - if 'mode' not in args or args['mode'] is None: - abort(400, message="mode is required") - - app_mode = AppMode.value_of(args['mode']) - - app_template = default_app_templates[app_mode] - - # get model config - default_model_config = app_template['model_config'] - if 'model' in default_model_config: - # get model provider - model_manager = ModelManager() - - # get default model instance - try: - model_instance = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.LLM - ) - except ProviderTokenNotInitError: - model_instance = None - - if model_instance: - if model_instance.model == default_model_config['model']['name']: - default_model_dict = default_model_config['model'] - else: - llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) - model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) - - default_model_dict = { - 'provider': model_instance.provider, - 'name': model_instance.model, - 'mode': model_schema.model_properties.get(ModelPropertyKey.MODE), - 'completion_params': {} - } - else: 
+ if model_instance: + if model_instance.model == default_model_config['model']['name']: default_model_dict = default_model_config['model'] + else: + llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) + model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) - default_model_config['model'] = json.dumps(default_model_dict) + default_model_dict = { + 'provider': model_instance.provider, + 'name': model_instance.model, + 'mode': model_schema.model_properties.get(ModelPropertyKey.MODE), + 'completion_params': {} + } + else: + default_model_dict = default_model_config['model'] - app = App(**app_template['app']) - app_model_config = AppModelConfig(**default_model_config) + default_model_config['model'] = json.dumps(default_model_dict) + app = App(**app_template['app']) app.name = args['name'] app.mode = args['mode'] app.icon = args['icon'] @@ -200,26 +146,14 @@ class AppListApi(Resource): db.session.add(app) db.session.flush() + app_model_config = AppModelConfig(**default_model_config) app_model_config.app_id = app.id db.session.add(app_model_config) db.session.flush() app.app_model_config_id = app_model_config.id - account = current_user - - site = Site( - app_id=app.id, - title=app.name, - default_language=account.interface_language, - customize_token_strategy='not_allow', - code=Site.generate_code(16) - ) - - db.session.add(site) - db.session.commit() - - app_was_created.send(app) + app_was_created.send(app, account=current_user) return app, 201 @@ -262,21 +196,16 @@ class AppImportApi(Resource): "when mode is advanced-chat or workflow") app = App( + tenant_id=current_user.current_tenant_id, + mode=app_data.get('mode'), + name=args.get("name") if args.get("name") else app_data.get('name'), + icon=args.get("icon") if args.get("icon") else app_data.get('icon'), + icon_background=args.get("icon_background") if args.get("icon_background") \ + else app_data.get('icon_background'), enable_site=True, - enable_api=True, - is_demo=False, - api_rpm=0, - api_rph=0, - status='normal' + enable_api=True ) - app.tenant_id = current_user.current_tenant_id - app.mode = app_data.get('mode') - app.name = args.get("name") if args.get("name") else app_data.get('name') - app.icon = args.get("icon") if args.get("icon") else app_data.get('icon') - app.icon_background = args.get("icon_background") if args.get("icon_background") \ - else app_data.get('icon_background') - db.session.add(app) db.session.commit() @@ -295,20 +224,7 @@ class AppImportApi(Resource): app.app_model_config_id = app_model_config.id - account = current_user - - site = Site( - app_id=app.id, - title=app.name, - default_language=account.interface_language, - customize_token_strategy='not_allow', - code=Site.generate_code(16) - ) - - db.session.add(site) - db.session.commit() - - app_was_created.send(app) + app_was_created.send(app, account=current_user) return app, 201 diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py index 88d226d303..fdfb401bd4 100644 --- a/api/events/event_handlers/__init__.py +++ b/api/events/event_handlers/__init__.py @@ -2,6 +2,7 @@ from .clean_when_dataset_deleted import handle from .clean_when_document_deleted import handle from .create_document_index import handle from .create_installed_app_when_app_created import handle +from .create_site_record_when_app_created import handle from .deduct_quota_when_messaeg_created import handle from .delete_installed_app_when_app_deleted import handle from 
.generate_conversation_name_when_first_message_created import handle diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py new file mode 100644 index 0000000000..25fba591d0 --- /dev/null +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -0,0 +1,20 @@ +from events.app_event import app_was_created +from extensions.ext_database import db +from models.model import Site + + +@app_was_created.connect +def handle(sender, **kwargs): + """Create site record when an app is created.""" + app = sender + account = kwargs.get('account') + site = Site( + app_id=app.id, + title=app.name, + default_language=account.interface_language, + customize_token_strategy='not_allow', + code=Site.generate_code(16) + ) + + db.session.add(site) + db.session.commit() diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index ed24762dd8..72c6d3f719 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -93,18 +93,7 @@ class WorkflowConverter: new_app.app_model_config_id = new_app_model_config.id db.session.commit() - site = Site( - app_id=new_app.id, - title=new_app.name, - default_language=account.interface_language, - customize_token_strategy='not_allow', - code=Site.generate_code(16) - ) - - db.session.add(site) - db.session.commit() - - app_was_created.send(new_app) + app_was_created.send(new_app, account=account) return new_app From 8b529a3ec7f912ac4a50c6b2463efda7c8363763 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 27 Feb 2024 14:25:39 +0800 Subject: [PATCH 067/200] refactor app api --- api/controllers/console/app/app.py | 210 ++----------- .../console/explore/recommended_app.py | 28 +- api/services/app_service.py | 281 ++++++++++++++++++ api/services/workflow/workflow_converter.py | 2 +- 4 files changed, 309 insertions(+), 212 deletions(-) create mode 100644 api/services/app_service.py diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 4d88733d5f..6c0d0ca9a6 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,29 +1,18 @@ -import json -from datetime import datetime -from typing import cast - -import yaml from flask_login import current_user from flask_restful import Resource, abort, inputs, marshal_with, reqparse -from werkzeug.exceptions import Forbidden +from werkzeug.exceptions import Forbidden, BadRequest -from constants.model_template import default_app_templates from controllers.console import api from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check -from core.errors.error import ProviderTokenNotInitError -from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType, ModelPropertyKey -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from events.app_event import app_was_created, app_was_deleted -from extensions.ext_database import db from fields.app_fields import ( app_detail_fields, app_detail_fields_with_site, app_pagination_fields, ) from libs.login import login_required +from services.app_service import AppService from models.model import App, AppModelConfig, AppMode from services.workflow_service import WorkflowService from 
core.tools.utils.configuration import ToolParameterConfigurationManager @@ -49,32 +38,9 @@ class AppListApi(Resource): parser.add_argument('name', type=str, location='args', required=False) args = parser.parse_args() - filters = [ - App.tenant_id == current_user.current_tenant_id, - App.is_universal == False - ] - - if args['mode'] == 'workflow': - filters.append(App.mode.in_([AppMode.WORKFLOW.value, AppMode.COMPLETION.value])) - elif args['mode'] == 'chat': - filters.append(App.mode.in_([AppMode.CHAT.value, AppMode.ADVANCED_CHAT.value])) - elif args['mode'] == 'agent': - filters.append(App.mode == AppMode.AGENT_CHAT.value) - elif args['mode'] == 'channel': - filters.append(App.mode == AppMode.CHANNEL.value) - else: - pass - - if 'name' in args and args['name']: - name = args['name'][:30] - filters.append(App.name.ilike(f'%{name}%')) - - app_models = db.paginate( - db.select(App).where(*filters).order_by(App.created_at.desc()), - page=args['page'], - per_page=args['limit'], - error_out=False - ) + # get app list + app_service = AppService() + app_models = app_service.get_paginate_apps(current_user.current_tenant_id, args) return app_models @@ -97,63 +63,10 @@ class AppListApi(Resource): raise Forbidden() if 'mode' not in args or args['mode'] is None: - abort(400, message="mode is required") + raise BadRequest("mode is required") - app_mode = AppMode.value_of(args['mode']) - - app_template = default_app_templates[app_mode] - - # get model config - default_model_config = app_template['model_config'] - if 'model' in default_model_config: - # get model provider - model_manager = ModelManager() - - # get default model instance - try: - model_instance = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.LLM - ) - except ProviderTokenNotInitError: - model_instance = None - - if model_instance: - if model_instance.model == default_model_config['model']['name']: - default_model_dict = default_model_config['model'] - else: - llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) - model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) - - default_model_dict = { - 'provider': model_instance.provider, - 'name': model_instance.model, - 'mode': model_schema.model_properties.get(ModelPropertyKey.MODE), - 'completion_params': {} - } - else: - default_model_dict = default_model_config['model'] - - default_model_config['model'] = json.dumps(default_model_dict) - - app = App(**app_template['app']) - app.name = args['name'] - app.mode = args['mode'] - app.icon = args['icon'] - app.icon_background = args['icon_background'] - app.tenant_id = current_user.current_tenant_id - - db.session.add(app) - db.session.flush() - - app_model_config = AppModelConfig(**default_model_config) - app_model_config.app_id = app.id - db.session.add(app_model_config) - db.session.flush() - - app.app_model_config_id = app_model_config.id - - app_was_created.send(app, account=current_user) + app_service = AppService() + app = app_service.create_app(current_user.current_tenant_id, args, current_user) return app, 201 @@ -177,54 +90,8 @@ class AppImportApi(Resource): parser.add_argument('icon_background', type=str, location='json') args = parser.parse_args() - try: - import_data = yaml.safe_load(args['data']) - except yaml.YAMLError as e: - raise ValueError("Invalid YAML format in data argument.") - - app_data = import_data.get('app') - model_config_data = import_data.get('model_config') - workflow_graph = 
import_data.get('workflow_graph') - - if not app_data or not model_config_data: - raise ValueError("Missing app or model_config in data argument") - - app_mode = AppMode.value_of(app_data.get('mode')) - if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: - if not workflow_graph: - raise ValueError("Missing workflow_graph in data argument " - "when mode is advanced-chat or workflow") - - app = App( - tenant_id=current_user.current_tenant_id, - mode=app_data.get('mode'), - name=args.get("name") if args.get("name") else app_data.get('name'), - icon=args.get("icon") if args.get("icon") else app_data.get('icon'), - icon_background=args.get("icon_background") if args.get("icon_background") \ - else app_data.get('icon_background'), - enable_site=True, - enable_api=True - ) - - db.session.add(app) - db.session.commit() - - if workflow_graph: - workflow_service = WorkflowService() - draft_workflow = workflow_service.sync_draft_workflow(app, workflow_graph, current_user) - published_workflow = workflow_service.publish_draft_workflow(app, current_user, draft_workflow) - model_config_data['workflow_id'] = published_workflow.id - - app_model_config = AppModelConfig() - app_model_config = app_model_config.from_model_config_dict(model_config_data) - app_model_config.app_id = app.id - - db.session.add(app_model_config) - db.session.commit() - - app.app_model_config_id = app_model_config.id - - app_was_created.send(app, account=current_user) + app_service = AppService() + app = app_service.import_app(current_user.current_tenant_id, args, current_user) return app, 201 @@ -286,13 +153,8 @@ class AppApi(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - db.session.delete(app_model) - db.session.commit() - - # todo delete related data?? - # model_config, site, api_token, conversation, message, message_feedback, message_annotation - - app_was_deleted.send(app_model) + app_service = AppService() + app_service.delete_app(app_model) return {'result': 'success'}, 204 @@ -304,28 +166,10 @@ class AppExportApi(Resource): @get_app_model def get(self, app_model): """Export app""" - app_model_config = app_model.app_model_config - - export_data = { - "app": { - "name": app_model.name, - "mode": app_model.mode, - "icon": app_model.icon, - "icon_background": app_model.icon_background - }, - "model_config": app_model_config.to_dict(), - } - - if app_model_config.workflow_id: - export_data['workflow_graph'] = json.loads(app_model_config.workflow.graph) - else: - # get draft workflow - workflow_service = WorkflowService() - workflow = workflow_service.get_draft_workflow(app_model) - export_data['workflow_graph'] = json.loads(workflow.graph) + app_service = AppService() return { - "data": yaml.dump(export_data) + "data": app_service.export_app(app_model) } @@ -340,9 +184,9 @@ class AppNameApi(Resource): parser.add_argument('name', type=str, required=True, location='json') args = parser.parse_args() - app_model.name = args.get('name') - app_model.updated_at = datetime.utcnow() - db.session.commit() + app_service = AppService() + app_model = app_service.update_app_name(app_model, args.get('name')) + return app_model @@ -358,10 +202,8 @@ class AppIconApi(Resource): parser.add_argument('icon_background', type=str, location='json') args = parser.parse_args() - app_model.icon = args.get('icon') - app_model.icon_background = args.get('icon_background') - app_model.updated_at = datetime.utcnow() - db.session.commit() + app_service = AppService() + app_model = app_service.update_app_icon(app_model, 
args.get('icon'), args.get('icon_background')) return app_model @@ -377,12 +219,9 @@ class AppSiteStatus(Resource): parser.add_argument('enable_site', type=bool, required=True, location='json') args = parser.parse_args() - if args.get('enable_site') == app_model.enable_site: - return app_model + app_service = AppService() + app_model = app_service.update_app_site_status(app_model, args.get('enable_site')) - app_model.enable_site = args.get('enable_site') - app_model.updated_at = datetime.utcnow() - db.session.commit() return app_model @@ -397,12 +236,9 @@ class AppApiStatus(Resource): parser.add_argument('enable_api', type=bool, required=True, location='json') args = parser.parse_args() - if args.get('enable_api') == app_model.enable_api: - return app_model + app_service = AppService() + app_model = app_service.update_app_api_status(app_model, args.get('enable_api')) - app_model.enable_api = args.get('enable_api') - app_model.updated_at = datetime.utcnow() - db.session.commit() return app_model diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 3c28980f51..8190f7828d 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,6 +1,3 @@ -import json - -import yaml from flask_login import current_user from flask_restful import Resource, fields, marshal_with, reqparse @@ -9,7 +6,7 @@ from controllers.console import api from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from models.model import App, RecommendedApp -from services.workflow_service import WorkflowService +from services.app_service import AppService app_fields = { 'id': fields.String, @@ -103,25 +100,8 @@ class RecommendedAppApi(Resource): if not app_model or not app_model.is_public: raise AppNotFoundError - app_model_config = app_model.app_model_config - - export_data = { - "app": { - "name": app_model.name, - "mode": app_model.mode, - "icon": app_model.icon, - "icon_background": app_model.icon_background - }, - "model_config": app_model_config.to_dict(), - } - - if app_model_config.workflow_id: - export_data['workflow_graph'] = json.loads(app_model_config.workflow.graph) - else: - # get draft workflow - workflow_service = WorkflowService() - workflow = workflow_service.get_draft_workflow(app_model) - export_data['workflow_graph'] = json.loads(workflow.graph) + app_service = AppService() + export_str = app_service.export_app(app_model) return { 'id': app_model.id, @@ -129,7 +109,7 @@ class RecommendedAppApi(Resource): 'icon': app_model.icon, 'icon_background': app_model.icon_background, 'mode': app_model.mode, - 'export_data': yaml.dump(export_data) + 'export_data': export_str } diff --git a/api/services/app_service.py b/api/services/app_service.py new file mode 100644 index 0000000000..e80c720d4c --- /dev/null +++ b/api/services/app_service.py @@ -0,0 +1,281 @@ +import json +from datetime import datetime +from typing import cast + +import yaml + +from constants.model_template import default_app_templates +from core.errors.error import ProviderTokenNotInitError +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType, ModelPropertyKey +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from events.app_event import app_was_created, app_was_deleted +from extensions.ext_database import db +from models.account import Account +from models.model import App, 
AppMode, AppModelConfig +from services.workflow_service import WorkflowService + + +class AppService: + def get_paginate_apps(self, tenant_id: str, args: dict) -> list[App]: + """ + Get app list with pagination + :param tenant_id: tenant id + :param args: request args + :return: + """ + filters = [ + App.tenant_id == tenant_id, + App.is_universal == False + ] + + if args['mode'] == 'workflow': + filters.append(App.mode.in_([AppMode.WORKFLOW.value, AppMode.COMPLETION.value])) + elif args['mode'] == 'chat': + filters.append(App.mode.in_([AppMode.CHAT.value, AppMode.ADVANCED_CHAT.value])) + elif args['mode'] == 'agent': + filters.append(App.mode == AppMode.AGENT_CHAT.value) + elif args['mode'] == 'channel': + filters.append(App.mode == AppMode.CHANNEL.value) + + if 'name' in args and args['name']: + name = args['name'][:30] + filters.append(App.name.ilike(f'%{name}%')) + + app_models = db.paginate( + db.select(App).where(*filters).order_by(App.created_at.desc()), + page=args['page'], + per_page=args['limit'], + error_out=False + ) + + return app_models + + def create_app(self, tenant_id: str, args: dict, account: Account) -> App: + """ + Create app + :param tenant_id: tenant id + :param args: request args + :param account: Account instance + """ + app_mode = AppMode.value_of(args['mode']) + app_template = default_app_templates[app_mode] + + # get model config + default_model_config = app_template['model_config'] + if 'model' in default_model_config: + # get model provider + model_manager = ModelManager() + + # get default model instance + try: + model_instance = model_manager.get_default_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM + ) + except ProviderTokenNotInitError: + model_instance = None + + if model_instance: + if model_instance.model == default_model_config['model']['name']: + default_model_dict = default_model_config['model'] + else: + llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) + model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + + default_model_dict = { + 'provider': model_instance.provider, + 'name': model_instance.model, + 'mode': model_schema.model_properties.get(ModelPropertyKey.MODE), + 'completion_params': {} + } + else: + default_model_dict = default_model_config['model'] + + default_model_config['model'] = json.dumps(default_model_dict) + + app = App(**app_template['app']) + app.name = args['name'] + app.mode = args['mode'] + app.icon = args['icon'] + app.icon_background = args['icon_background'] + app.tenant_id = tenant_id + + db.session.add(app) + db.session.flush() + + app_model_config = AppModelConfig(**default_model_config) + app_model_config.app_id = app.id + db.session.add(app_model_config) + db.session.flush() + + app.app_model_config_id = app_model_config.id + + app_was_created.send(app, account=account) + + return app + + def import_app(self, tenant_id: str, args: dict, account: Account) -> App: + """ + Import app + :param tenant_id: tenant id + :param args: request args + :param account: Account instance + """ + try: + import_data = yaml.safe_load(args['data']) + except yaml.YAMLError as e: + raise ValueError("Invalid YAML format in data argument.") + + app_data = import_data.get('app') + model_config_data = import_data.get('model_config') + workflow_graph = import_data.get('workflow_graph') + + if not app_data or not model_config_data: + raise ValueError("Missing app or model_config in data argument") + + app_mode =
AppMode.value_of(app_data.get('mode')) + if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if not workflow_graph: + raise ValueError("Missing workflow_graph in data argument " + "when mode is advanced-chat or workflow") + + app = App( + tenant_id=tenant_id, + mode=app_data.get('mode'), + name=args.get("name") if args.get("name") else app_data.get('name'), + icon=args.get("icon") if args.get("icon") else app_data.get('icon'), + icon_background=args.get("icon_background") if args.get("icon_background") \ + else app_data.get('icon_background'), + enable_site=True, + enable_api=True + ) + + db.session.add(app) + db.session.commit() + + if workflow_graph: + workflow_service = WorkflowService() + draft_workflow = workflow_service.sync_draft_workflow(app, workflow_graph, account) + published_workflow = workflow_service.publish_draft_workflow(app, account, draft_workflow) + model_config_data['workflow_id'] = published_workflow.id + + app_model_config = AppModelConfig() + app_model_config = app_model_config.from_model_config_dict(model_config_data) + app_model_config.app_id = app.id + + db.session.add(app_model_config) + db.session.commit() + + app.app_model_config_id = app_model_config.id + + app_was_created.send(app, account=account) + + return app + + def export_app(self, app: App) -> str: + """ + Export app + :param app: App instance + :return: + """ + app_model_config = app.app_model_config + + export_data = { + "app": { + "name": app.name, + "mode": app.mode, + "icon": app.icon, + "icon_background": app.icon_background + }, + "model_config": app_model_config.to_dict(), + } + + if app_model_config.workflow_id: + export_data['workflow_graph'] = json.loads(app_model_config.workflow.graph) + else: + workflow_service = WorkflowService() + workflow = workflow_service.get_draft_workflow(app) + export_data['workflow_graph'] = json.loads(workflow.graph) + + return yaml.dump(export_data) + + def update_app_name(self, app: App, name: str) -> App: + """ + Update app name + :param app: App instance + :param name: new name + :return: App instance + """ + app.name = name + app.updated_at = datetime.utcnow() + db.session.commit() + + return app + + def update_app_icon(self, app: App, icon: str, icon_background: str) -> App: + """ + Update app icon + :param app: App instance + :param icon: new icon + :param icon_background: new icon_background + :return: App instance + """ + app.icon = icon + app.icon_background = icon_background + app.updated_at = datetime.utcnow() + db.session.commit() + + return app + + def update_app_site_status(self, app: App, enable_site: bool) -> App: + """ + Update app site status + :param app: App instance + :param enable_site: enable site status + :return: App instance + """ + if enable_site == app.enable_site: + return app + + app.enable_site = enable_site + app.updated_at = datetime.utcnow() + db.session.commit() + + return app + + def update_app_api_status(self, app: App, enable_api: bool) -> App: + """ + Update app api status + :param app: App instance + :param enable_api: enable api status + :return: App instance + """ + if enable_api == app.enable_api: + return app + + app.enable_api = enable_api + app.updated_at = datetime.utcnow() + db.session.commit() + + return app + + def delete_app(self, app: App) -> None: + """ + Delete app + :param app: App instance + """ + db.session.delete(app) + db.session.commit() + + app_was_deleted.send(app) + + # todo async delete related data by event + # app_model_configs, site, api_tokens, installed_apps, recommended_apps BY app + # 
app_annotation_hit_histories, app_annotation_settings, app_dataset_joins BY app + # workflows, workflow_runs, workflow_node_executions, workflow_app_logs BY app + # conversations, pinned_conversations, messages BY app + # message_feedbacks, message_annotations, message_chains BY message + # message_agent_thoughts, message_files, saved_messages BY message + + diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 72c6d3f719..fb6cf1fd5a 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -21,7 +21,7 @@ from events.app_event import app_was_created from extensions.ext_database import db from models.account import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint -from models.model import App, AppMode, AppModelConfig, Site +from models.model import App, AppMode, AppModelConfig from models.workflow import Workflow, WorkflowType From 4f50f113dd192e6dfd5d4164bf7fc0e0a26962fb Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 27 Feb 2024 14:25:49 +0800 Subject: [PATCH 068/200] lint fix --- api/services/app_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/services/app_service.py b/api/services/app_service.py index e80c720d4c..f3a12a8b9c 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -7,7 +7,7 @@ import yaml from constants.model_template import default_app_templates from core.errors.error import ProviderTokenNotInitError from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType, ModelPropertyKey +from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from events.app_event import app_was_created, app_was_deleted from extensions.ext_database import db From a457faa2bf488ca7ad8dcee4b2a4103c0f3da506 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 27 Feb 2024 14:28:40 +0800 Subject: [PATCH 069/200] trigger app_model_config_was_updated when app import --- api/services/app_service.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/api/services/app_service.py b/api/services/app_service.py index f3a12a8b9c..375c102114 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -9,7 +9,7 @@ from core.errors.error import ProviderTokenNotInitError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from events.app_event import app_was_created, app_was_deleted +from events.app_event import app_was_created, app_was_deleted, app_model_config_was_updated from extensions.ext_database import db from models.account import Account from models.model import App, AppMode, AppModelConfig @@ -171,6 +171,11 @@ class AppService: app_was_created.send(app, account=account) + app_model_config_was_updated.send( + app, + app_model_config=app_model_config + ) + return app def export_app(self, app: App) -> str: From 742b87df5e3be98cfa47ecff3b7ae160f0f060ff Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 27 Feb 2024 14:29:17 +0800 Subject: [PATCH 070/200] lint fix --- api/services/app_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/services/app_service.py b/api/services/app_service.py index 
375c102114..a83c7e6ac4 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -9,7 +9,7 @@ from core.errors.error import ProviderTokenNotInitError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from events.app_event import app_was_created, app_was_deleted, app_model_config_was_updated +from events.app_event import app_model_config_was_updated, app_was_created, app_was_deleted from extensions.ext_database import db from models.account import Account from models.model import App, AppMode, AppModelConfig From 7d51d6030be5896bb3f4299cff4387d6b50255d4 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 27 Feb 2024 14:36:42 +0800 Subject: [PATCH 071/200] remove publish workflow when app import --- api/services/app_service.py | 7 ++----- api/services/workflow_service.py | 34 ++++++++++++++++++++++++++++---- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/api/services/app_service.py b/api/services/app_service.py index a83c7e6ac4..6955a6dccb 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -155,10 +155,9 @@ class AppService: db.session.commit() if workflow_graph: + # init draft workflow workflow_service = WorkflowService() - draft_workflow = workflow_service.sync_draft_workflow(app, workflow_graph, account) - published_workflow = workflow_service.publish_draft_workflow(app, account, draft_workflow) - model_config_data['workflow_id'] = published_workflow.id + workflow_service.sync_draft_workflow(app, workflow_graph, account) app_model_config = AppModelConfig() app_model_config = app_model_config.from_model_config_dict(model_config_data) @@ -282,5 +281,3 @@ class AppService: # conversations, pinned_conversations, messages BY app # message_feedbacks, message_annotations, message_chains BY message # message_agent_thoughts, message_files, saved_messages BY message - - diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 3143818d12..dac88d6396 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -59,11 +59,11 @@ class WorkflowService: # return draft workflow return workflow - def publish_draft_workflow(self, app_model: App, - account: Account, - draft_workflow: Optional[Workflow] = None) -> Workflow: + def publish_workflow(self, app_model: App, + account: Account, + draft_workflow: Optional[Workflow] = None) -> Workflow: """ - Publish draft workflow + Publish workflow from draft :param app_model: App instance :param account: Account instance @@ -76,6 +76,8 @@ class WorkflowService: if not draft_workflow: raise ValueError('No valid workflow found.') + # TODO check if the workflow is valid + # create new workflow workflow = Workflow( tenant_id=app_model.tenant_id, @@ -90,6 +92,30 @@ class WorkflowService: db.session.add(workflow) db.session.commit() + app_model_config = app_model.app_model_config + + # create new app model config record + new_app_model_config = app_model_config.copy() + new_app_model_config.id = None + new_app_model_config.app_id = app_model.id + new_app_model_config.external_data_tools = '' + new_app_model_config.model = '' + new_app_model_config.user_input_form = '' + new_app_model_config.dataset_query_variable = None + new_app_model_config.pre_prompt = None + new_app_model_config.agent_mode = '' + new_app_model_config.prompt_type = 'simple' + new_app_model_config.chat_prompt_config 
= '' + new_app_model_config.completion_prompt_config = '' + new_app_model_config.dataset_configs = '' + new_app_model_config.workflow_id = workflow.id + + db.session.add(new_app_model_config) + db.session.flush() + + app_model.app_model_config_id = new_app_model_config.id + db.session.commit() + # return new workflow return workflow From 03749917f04be9ef5473ca3e72f84e62cab24c98 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 27 Feb 2024 18:03:47 +0800 Subject: [PATCH 072/200] add workflow app log api --- api/controllers/console/__init__.py | 2 +- api/controllers/console/app/app.py | 4 +- api/controllers/console/app/workflow.py | 36 +++++++++++ .../console/app/workflow_app_log.py | 41 ++++++++++++ api/fields/end_user_fields.py | 8 +++ api/fields/workflow_app_log_fields.py | 25 ++++++++ api/fields/workflow_fields.py | 13 ++++ api/models/__init__.py | 45 +++++++++++++- api/models/workflow.py | 20 +++++- api/services/app_service.py | 3 +- api/services/workflow_app_service.py | 62 +++++++++++++++++++ api/services/workflow_service.py | 24 ++++++- 12 files changed, 276 insertions(+), 7 deletions(-) create mode 100644 api/controllers/console/app/workflow_app_log.py create mode 100644 api/fields/end_user_fields.py create mode 100644 api/fields/workflow_app_log_fields.py create mode 100644 api/services/workflow_app_service.py diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 649df278ec..a6f803785a 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -8,7 +8,7 @@ api = ExternalApi(bp) from . import admin, apikey, extension, feature, setup, version, ping # Import app controllers from .app import (advanced_prompt_template, annotation, app, audio, completion, conversation, generator, message, - model_config, site, statistic, workflow) + model_config, site, statistic, workflow, workflow_app_log) # Import auth controllers from .auth import activate, data_source_oauth, login, oauth # Import billing controllers diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 6c0d0ca9a6..898fd4f7c4 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -40,9 +40,9 @@ class AppListApi(Resource): # get app list app_service = AppService() - app_models = app_service.get_paginate_apps(current_user.current_tenant_id, args) + app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args) - return app_models + return app_pagination @setup_required @login_required diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 6023d0ba45..8e51ae8cbd 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -51,6 +51,41 @@ class DraftWorkflowApi(Resource): } +class PublishedWorkflowApi(Resource): + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @marshal_with(workflow_fields) + def get(self, app_model: App): + """ + Get published workflow + """ + # fetch published workflow by app_model + workflow_service = WorkflowService() + workflow = workflow_service.get_published_workflow(app_model=app_model) + + # return workflow, if not found, return None + return workflow + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def post(self, app_model: App): + """ + Publish workflow + """ + workflow_service = 
WorkflowService() + workflow_service.publish_workflow(app_model=app_model, account=current_user) + + return { + "result": "success" + } + + + class DefaultBlockConfigApi(Resource): @setup_required @login_required @@ -88,5 +123,6 @@ class ConvertToWorkflowApi(Resource): api.add_resource(DraftWorkflowApi, '/apps//workflows/draft') +api.add_resource(PublishedWorkflowApi, '/apps//workflows/published') api.add_resource(DefaultBlockConfigApi, '/apps//workflows/default-workflow-block-configs') api.add_resource(ConvertToWorkflowApi, '/apps//convert-to-workflow') diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py new file mode 100644 index 0000000000..87614d549d --- /dev/null +++ b/api/controllers/console/app/workflow_app_log.py @@ -0,0 +1,41 @@ +from flask_restful import Resource, marshal_with, reqparse +from flask_restful.inputs import int_range + +from controllers.console import api +from controllers.console.app.wraps import get_app_model +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from fields.workflow_app_log_fields import workflow_app_log_pagination_fields +from libs.login import login_required +from models.model import AppMode, App +from services.workflow_app_service import WorkflowAppService + + +class WorkflowAppLogApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.WORKFLOW]) + @marshal_with(workflow_app_log_pagination_fields) + def get(self, app_model: App): + """ + Get workflow app logs + """ + parser = reqparse.RequestParser() + parser.add_argument('keyword', type=str, location='args') + parser.add_argument('status', type=str, choices=['succeeded', 'failed', 'stopped'], location='args') + parser.add_argument('page', type=int_range(1, 99999), default=1, location='args') + parser.add_argument('limit', type=int_range(1, 100), default=20, location='args') + args = parser.parse_args() + + # get paginate workflow app logs + workflow_app_service = WorkflowAppService() + workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs( + app_model=app_model, + args=args + ) + + return workflow_app_log_pagination + + +api.add_resource(WorkflowAppLogApi, '/apps//workflow-app-logs') diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py new file mode 100644 index 0000000000..ee630c12c2 --- /dev/null +++ b/api/fields/end_user_fields.py @@ -0,0 +1,8 @@ +from flask_restful import fields + +simple_end_user_fields = { + 'id': fields.String, + 'type': fields.String, + 'is_anonymous': fields.Boolean, + 'session_id': fields.String, +} diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py new file mode 100644 index 0000000000..6862f0411d --- /dev/null +++ b/api/fields/workflow_app_log_fields.py @@ -0,0 +1,25 @@ +from flask_restful import fields + +from fields.end_user_fields import simple_end_user_fields +from fields.member_fields import simple_account_fields +from fields.workflow_fields import workflow_run_fields +from libs.helper import TimestampField + + +workflow_app_log_partial_fields = { + "id": fields.String, + "workflow_run": fields.Nested(workflow_run_fields, attribute='workflow_run', allow_null=True), + "created_from": fields.String, + "created_by_role": fields.String, + "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), + "created_by_end_user": 
fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True), + "created_at": TimestampField +} + +workflow_app_log_pagination_fields = { + 'page': fields.Integer, + 'limit': fields.Integer(attribute='per_page'), + 'total': fields.Integer, + 'has_more': fields.Boolean(attribute='has_next'), + 'data': fields.List(fields.Nested(workflow_app_log_partial_fields), attribute='items') +} diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index decdc0567f..091f293150 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -13,3 +13,16 @@ workflow_fields = { 'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), 'updated_at': TimestampField } + +workflow_run_fields = { + "id": fields.String, + "version": fields.String, + "status": fields.String, + "error": fields.String, + "elapsed_time": fields.Float, + "total_tokens": fields.Integer, + "total_price": fields.Float, + "currency": fields.String, + "total_steps": fields.Integer, + "finished_at": TimestampField +} \ No newline at end of file diff --git a/api/models/__init__.py b/api/models/__init__.py index 44d37d3052..47eec53542 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -1 +1,44 @@ -# -*- coding:utf-8 -*- \ No newline at end of file +from enum import Enum + + +class CreatedByRole(Enum): + """ + Enum class for createdByRole + """ + ACCOUNT = "account" + END_USER = "end_user" + + @classmethod + def value_of(cls, value: str) -> 'CreatedByRole': + """ + Get role of given value. + + :param value: role value + :return: role + """ + for role in cls: + if role.value == value: + return role + raise ValueError(f'invalid createdByRole value {value}') + + +class CreatedFrom(Enum): + """ + Enum class for createdFrom + """ + SERVICE_API = "service-api" + WEB_APP = "web-app" + EXPLORE = "explore" + + @classmethod + def value_of(cls, value: str) -> 'CreatedFrom': + """ + Get created-from of given value.
+ + :param value: created-from value + :return: created from + """ + for created_from in cls: + if created_from.value == value: + return created_from + raise ValueError(f'invalid createdFrom value {value}') diff --git a/api/models/workflow.py b/api/models/workflow.py index 251f33b0c0..41266fe9f5 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -5,6 +5,7 @@ from sqlalchemy.dialects.postgresql import UUID from extensions.ext_database import db from models.account import Account +from models.model import EndUser class CreatedByRole(Enum): @@ -148,6 +149,7 @@ class WorkflowRunStatus(Enum): RUNNING = 'running' SUCCEEDED = 'succeeded' FAILED = 'failed' + STOPPED = 'stopped' @classmethod def value_of(cls, value: str) -> 'WorkflowRunStatus': @@ -184,7 +186,7 @@ class WorkflowRun(db.Model): - version (string) Version - graph (text) Workflow canvas configuration (JSON) - inputs (text) Input parameters - - status (string) Execution status, `running` / `succeeded` / `failed` + - status (string) Execution status, `running` / `succeeded` / `failed` / `stopped` - outputs (text) `optional` Output content - error (string) `optional` Error reason - elapsed_time (float) `optional` Time consumption (s) @@ -366,3 +368,19 @@ class WorkflowAppLog(db.Model): created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(UUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + @property + def workflow_run(self): + return WorkflowRun.query.get(self.workflow_run_id) + + @property + def created_by_account(self): + created_by_role = CreatedByRole.value_of(self.created_by_role) + return Account.query.get(self.created_by) \ + if created_by_role == CreatedByRole.ACCOUNT else None + + @property + def created_by_end_user(self): + created_by_role = CreatedByRole.value_of(self.created_by_role) + return EndUser.query.get(self.created_by) \ + if created_by_role == CreatedByRole.END_USER else None diff --git a/api/services/app_service.py b/api/services/app_service.py index 6955a6dccb..5de87dbad5 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -3,6 +3,7 @@ from datetime import datetime from typing import cast import yaml +from flask_sqlalchemy.pagination import Pagination from constants.model_template import default_app_templates from core.errors.error import ProviderTokenNotInitError @@ -17,7 +18,7 @@ from services.workflow_service import WorkflowService class AppService: - def get_paginate_apps(self, tenant_id: str, args: dict) -> list[App]: + def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination: """ Get app list with pagination :param tenant_id: tenant id diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py new file mode 100644 index 0000000000..5897fcf182 --- /dev/null +++ b/api/services/workflow_app_service.py @@ -0,0 +1,62 @@ +from flask_sqlalchemy.pagination import Pagination +from sqlalchemy import or_, and_ + +from extensions.ext_database import db +from models import CreatedByRole +from models.model import App, EndUser +from models.workflow import WorkflowAppLog, WorkflowRunStatus, WorkflowRun + + +class WorkflowAppService: + + def get_paginate_workflow_app_logs(self, app_model: App, args: dict) -> Pagination: + """ + Get paginate workflow app logs + :param app_model: app model + :param args: request args + :return: + """ + query = ( + db.select(WorkflowAppLog) + .where( + WorkflowAppLog.tenant_id == app_model.tenant_id, + WorkflowAppLog.app_id == app_model.id + ) + ) + +
status = WorkflowRunStatus.value_of(args.get('status')) if args.get('status') else None + if args['keyword'] or status: + query = query.join( + WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id + ) + + if args['keyword']: + keyword_val = f"%{args['keyword'][:30]}%" + keyword_conditions = [ + WorkflowRun.inputs.ilike(keyword_val), + WorkflowRun.outputs.ilike(keyword_val), + # filter keyword by end user session id if created by end user role + and_(WorkflowRun.created_by_role == 'end_user', EndUser.session_id.ilike(keyword_val)) + ] + + query = query.outerjoin( + EndUser, + and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value) + ).filter(or_(*keyword_conditions)) + + if status: + # join with workflow_run and filter by status + query = query.filter( + WorkflowRun.status == status.value + ) + + query = query.order_by(WorkflowAppLog.created_at.desc()) + + pagination = db.paginate( + query, + page=args['page'], + per_page=args['limit'], + error_out=False + ) + + return pagination diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index dac88d6396..ae6e4c46d3 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -15,7 +15,7 @@ class WorkflowService: Workflow Service """ - def get_draft_workflow(self, app_model: App) -> Workflow: + def get_draft_workflow(self, app_model: App) -> Optional[Workflow]: """ Get draft workflow """ @@ -29,6 +29,26 @@ class WorkflowService: # return draft workflow return workflow + def get_published_workflow(self, app_model: App) -> Optional[Workflow]: + """ + Get published workflow + """ + app_model_config = app_model.app_model_config + + if not app_model_config.workflow_id: + return None + + # fetch published workflow by workflow_id + workflow = db.session.query(Workflow).filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.id == app_model_config.workflow_id + ).first() + + # return published workflow + return workflow + + def sync_draft_workflow(self, app_model: App, graph: dict, account: Account) -> Workflow: """ Sync draft workflow @@ -116,6 +136,8 @@ class WorkflowService: app_model.app_model_config_id = new_app_model_config.id db.session.commit() + # TODO update app related datasets + # return new workflow return workflow From bf4a5f6b33f8516bc0392e3c2d07284393c2914f Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 27 Feb 2024 18:04:01 +0800 Subject: [PATCH 073/200] lint fix --- api/controllers/console/app/workflow_app_log.py | 2 +- api/fields/workflow_app_log_fields.py | 1 - api/services/workflow_app_service.py | 4 ++-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 87614d549d..6d1709ed8e 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -7,7 +7,7 @@ from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from fields.workflow_app_log_fields import workflow_app_log_pagination_fields from libs.login import login_required -from models.model import AppMode, App +from models.model import App, AppMode from services.workflow_app_service import WorkflowAppService diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index 6862f0411d..8f3998d90a 100644 --- a/api/fields/workflow_app_log_fields.py +++ 
b/api/fields/workflow_app_log_fields.py @@ -5,7 +5,6 @@ from fields.member_fields import simple_account_fields from fields.workflow_fields import workflow_run_fields from libs.helper import TimestampField - workflow_app_log_partial_fields = { "id": fields.String, "workflow_run": fields.Nested(workflow_run_fields, attribute='workflow_run', allow_null=True), diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index 5897fcf182..0476788375 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -1,10 +1,10 @@ from flask_sqlalchemy.pagination import Pagination -from sqlalchemy import or_, and_ +from sqlalchemy import and_, or_ from extensions.ext_database import db from models import CreatedByRole from models.model import App, EndUser -from models.workflow import WorkflowAppLog, WorkflowRunStatus, WorkflowRun +from models.workflow import WorkflowAppLog, WorkflowRun, WorkflowRunStatus class WorkflowAppService: From 20cf075b2dacc54fc3d5ee713d3b94850f0a8db2 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 27 Feb 2024 21:39:13 +0800 Subject: [PATCH 074/200] add workflow runs & workflow node executions api --- api/controllers/console/app/workflow.py | 60 +++++++++++- api/controllers/console/app/workflow_run.py | 80 ++++++++++++++++ api/fields/conversation_fields.py | 1 + api/fields/workflow_app_log_fields.py | 4 +- api/fields/workflow_fields.py | 13 --- api/fields/workflow_run_fields.py | 92 +++++++++++++++++++ .../versions/b289e2408ee2_add_workflow.py | 2 +- api/models/workflow.py | 45 ++++++++- api/services/workflow_run_service.py | 89 ++++++++++++++++++ 9 files changed, 365 insertions(+), 21 deletions(-) create mode 100644 api/controllers/console/app/workflow_run.py create mode 100644 api/fields/workflow_run_fields.py create mode 100644 api/services/workflow_run_service.py diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 8e51ae8cbd..4fcf8daf6e 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -51,6 +51,62 @@ class DraftWorkflowApi(Resource): } +class DraftWorkflowRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def post(self, app_model: App): + """ + Run draft workflow + """ + # TODO + workflow_service = WorkflowService() + workflow_service.run_draft_workflow(app_model=app_model, account=current_user) + + # TODO + return { + "result": "success" + } + + +class WorkflowTaskStopApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def post(self, app_model: App, task_id: str): + """ + Stop workflow task + """ + # TODO + workflow_service = WorkflowService() + workflow_service.stop_workflow_task(app_model=app_model, task_id=task_id, account=current_user) + + return { + "result": "success" + } + + +class DraftWorkflowNodeRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def post(self, app_model: App, node_id: str): + """ + Run draft workflow node + """ + # TODO + workflow_service = WorkflowService() + workflow_service.run_draft_workflow_node(app_model=app_model, node_id=node_id, account=current_user) + + # TODO + return { + "result": "success" + } + + class PublishedWorkflowApi(Resource): @setup_required @@ -85,7 
+141,6 @@ class PublishedWorkflowApi(Resource): } - class DefaultBlockConfigApi(Resource): @setup_required @login_required @@ -123,6 +178,9 @@ class ConvertToWorkflowApi(Resource): api.add_resource(DraftWorkflowApi, '/apps//workflows/draft') +api.add_resource(DraftWorkflowRunApi, '/apps//workflows/draft/run') +api.add_resource(WorkflowTaskStopApi, '/apps//workflows/tasks//stop') +api.add_resource(DraftWorkflowNodeRunApi, '/apps//workflows/draft/nodes//run') api.add_resource(PublishedWorkflowApi, '/apps//workflows/published') api.add_resource(DefaultBlockConfigApi, '/apps//workflows/default-workflow-block-configs') api.add_resource(ConvertToWorkflowApi, '/apps//convert-to-workflow') diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py new file mode 100644 index 0000000000..38e3d4d837 --- /dev/null +++ b/api/controllers/console/app/workflow_run.py @@ -0,0 +1,80 @@ +from flask_restful import Resource, marshal_with, reqparse +from flask_restful.inputs import int_range + +from controllers.console import api +from controllers.console.app.wraps import get_app_model +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from fields.workflow_run_fields import workflow_run_detail_fields, workflow_run_pagination_fields, \ + workflow_run_node_execution_list_fields +from libs.helper import uuid_value +from libs.login import login_required +from models.model import App, AppMode +from services.workflow_run_service import WorkflowRunService + + +class WorkflowRunListApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @marshal_with(workflow_run_pagination_fields) + def get(self, app_model: App): + """ + Get workflow run list + """ + parser = reqparse.RequestParser() + parser.add_argument('last_id', type=uuid_value, location='args') + parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + args = parser.parse_args() + + workflow_run_service = WorkflowRunService() + result = workflow_run_service.get_paginate_workflow_runs( + app_model=app_model, + args=args + ) + + return result + + +class WorkflowRunDetailApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @marshal_with(workflow_run_detail_fields) + def get(self, app_model: App, run_id): + """ + Get workflow run detail + """ + run_id = str(run_id) + + workflow_run_service = WorkflowRunService() + workflow_run = workflow_run_service.get_workflow_run(app_model=app_model, run_id=run_id) + + return workflow_run + + +class WorkflowRunNodeExecutionListApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @marshal_with(workflow_run_node_execution_list_fields) + def get(self, app_model: App, run_id): + """ + Get workflow run node execution list + """ + run_id = str(run_id) + + workflow_run_service = WorkflowRunService() + node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id) + + return { + 'data': node_executions + } + + +api.add_resource(WorkflowRunListApi, '/apps//workflow-runs') +api.add_resource(WorkflowRunDetailApi, '/apps//workflow-runs/') +api.add_resource(WorkflowRunNodeExecutionListApi, '/apps//workflow-runs//node-executions') diff --git 
a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index afa486f1cd..747b0b86ab 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -66,6 +66,7 @@ message_detail_fields = { 'from_end_user_id': fields.String, 'from_account_id': fields.String, 'feedbacks': fields.List(fields.Nested(feedback_fields)), + 'workflow_run_id': fields.String, 'annotation': fields.Nested(annotation_fields, allow_null=True), 'annotation_hit_history': fields.Nested(annotation_hit_history_fields, allow_null=True), 'created_at': TimestampField, diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index 8f3998d90a..e230c159fb 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -2,12 +2,12 @@ from flask_restful import fields from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields -from fields.workflow_fields import workflow_run_fields +from fields.workflow_run_fields import workflow_run_for_log_fields from libs.helper import TimestampField workflow_app_log_partial_fields = { "id": fields.String, - "workflow_run": fields.Nested(workflow_run_fields, attribute='workflow_run', allow_null=True), + "workflow_run": fields.Nested(workflow_run_for_log_fields, attribute='workflow_run', allow_null=True), "created_from": fields.String, "created_by_role": fields.String, "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 091f293150..decdc0567f 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -13,16 +13,3 @@ workflow_fields = { 'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), 'updated_at': TimestampField } - -workflow_run_fields = { - "id": fields.String, - "version": fields.String, - "status": fields.String, - "error": fields.String, - "elapsed_time": fields.Float, - "total_tokens": fields.Integer, - "total_price": fields.Float, - "currency": fields.String, - "total_steps": fields.Integer, - "finished_at": TimestampField -} \ No newline at end of file diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py new file mode 100644 index 0000000000..37751bc70f --- /dev/null +++ b/api/fields/workflow_run_fields.py @@ -0,0 +1,92 @@ +from flask_restful import fields + +from fields.end_user_fields import simple_end_user_fields +from fields.member_fields import simple_account_fields +from libs.helper import TimestampField + +workflow_run_for_log_fields = { + "id": fields.String, + "version": fields.String, + "status": fields.String, + "error": fields.String, + "elapsed_time": fields.Float, + "total_tokens": fields.Integer, + "total_price": fields.Float, + "currency": fields.String, + "total_steps": fields.Integer, + "created_at": TimestampField, + "finished_at": TimestampField +} + +workflow_run_for_list_fields = { + "id": fields.String, + "sequence_number": fields.Integer, + "version": fields.String, + "graph": fields.String, + "inputs": fields.String, + "status": fields.String, + "outputs": fields.String, + "error": fields.String, + "elapsed_time": fields.Float, + "total_tokens": fields.Integer, + "total_price": fields.Float, + "currency": fields.String, + "total_steps": fields.Integer, + "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), + 
"created_at": TimestampField, + "finished_at": TimestampField +} + +workflow_run_pagination_fields = { + 'page': fields.Integer, + 'limit': fields.Integer(attribute='per_page'), + 'total': fields.Integer, + 'has_more': fields.Boolean(attribute='has_next'), + 'data': fields.List(fields.Nested(workflow_run_for_list_fields), attribute='items') +} + +workflow_run_detail_fields = { + "id": fields.String, + "sequence_number": fields.Integer, + "version": fields.String, + "graph": fields.String, + "inputs": fields.String, + "status": fields.String, + "outputs": fields.String, + "error": fields.String, + "elapsed_time": fields.Float, + "total_tokens": fields.Integer, + "total_price": fields.Float, + "currency": fields.String, + "total_steps": fields.Integer, + "created_by_role": fields.String, + "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True), + "created_at": TimestampField, + "finished_at": TimestampField +} + +workflow_run_node_execution_fields = { + "id": fields.String, + "index": fields.Integer, + "predecessor_node_id": fields.String, + "node_id": fields.String, + "node_type": fields.String, + "title": fields.String, + "inputs": fields.String, + "process_data": fields.String, + "outputs": fields.String, + "status": fields.String, + "error": fields.String, + "elapsed_time": fields.Float, + "execution_metadata": fields.String, + "created_at": TimestampField, + "created_by_role": fields.String, + "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True), + "finished_at": TimestampField +} + +workflow_run_node_execution_list_fields = { + 'data': fields.List(fields.Nested(workflow_run_node_execution_fields)), +} diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py index 7255b4b5fa..5f7ddc7d68 100644 --- a/api/migrations/versions/b289e2408ee2_add_workflow.py +++ b/api/migrations/versions/b289e2408ee2_add_workflow.py @@ -88,7 +88,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='workflow_run_pkey') ) with op.batch_alter_table('workflow_runs', schema=None) as batch_op: - batch_op.create_index('workflow_run_triggerd_from_idx', ['tenant_id', 'app_id', 'workflow_id', 'triggered_from'], unique=False) + batch_op.create_index('workflow_run_triggerd_from_idx', ['tenant_id', 'app_id', 'triggered_from'], unique=False) op.create_table('workflows', sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), diff --git a/api/models/workflow.py b/api/models/workflow.py index 41266fe9f5..7ea342cda7 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -208,7 +208,7 @@ class WorkflowRun(db.Model): __tablename__ = 'workflow_runs' __table_args__ = ( db.PrimaryKeyConstraint('id', name='workflow_run_pkey'), - db.Index('workflow_run_triggerd_from_idx', 'tenant_id', 'app_id', 'workflow_id', 'triggered_from'), + db.Index('workflow_run_triggerd_from_idx', 'tenant_id', 'app_id', 'triggered_from'), ) id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) @@ -236,11 +236,36 @@ class WorkflowRun(db.Model): @property def created_by_account(self): - return Account.query.get(self.created_by) + created_by_role = CreatedByRole.value_of(self.created_by_role) + return 
Account.query.get(self.created_by) \
+            if created_by_role == CreatedByRole.ACCOUNT else None
 
     @property
-    def updated_by_account(self):
-        return Account.query.get(self.updated_by)
+    def created_by_end_user(self):
+        created_by_role = CreatedByRole.value_of(self.created_by_role)
+        return EndUser.query.get(self.created_by) \
+            if created_by_role == CreatedByRole.END_USER else None
+
+
+class WorkflowNodeExecutionTriggeredFrom(Enum):
+    """
+    Workflow Node Execution Triggered From Enum
+    """
+    SINGLE_STEP = 'single-step'
+    WORKFLOW_RUN = 'workflow-run'
+
+    @classmethod
+    def value_of(cls, value: str) -> 'WorkflowNodeExecutionTriggeredFrom':
+        """
+        Get value of given mode.
+
+        :param value: mode value
+        :return: mode
+        """
+        for mode in cls:
+            if mode.value == value:
+                return mode
+        raise ValueError(f'invalid workflow node execution triggered from value {value}')
 
 
 class WorkflowNodeExecution(db.Model):
@@ -323,6 +348,18 @@ class WorkflowNodeExecution(db.Model):
     created_by = db.Column(UUID, nullable=False)
     finished_at = db.Column(db.DateTime)
 
+    @property
+    def created_by_account(self):
+        created_by_role = CreatedByRole.value_of(self.created_by_role)
+        return Account.query.get(self.created_by) \
+            if created_by_role == CreatedByRole.ACCOUNT else None
+
+    @property
+    def created_by_end_user(self):
+        created_by_role = CreatedByRole.value_of(self.created_by_role)
+        return EndUser.query.get(self.created_by) \
+            if created_by_role == CreatedByRole.END_USER else None
+
 
 class WorkflowAppLog(db.Model):
     """
diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py
new file mode 100644
index 0000000000..9c898f10fb
--- /dev/null
+++ b/api/services/workflow_run_service.py
@@ -0,0 +1,89 @@
+from extensions.ext_database import db
+from libs.infinite_scroll_pagination import InfiniteScrollPagination
+from models.model import App
+from models.workflow import WorkflowRun, WorkflowRunTriggeredFrom, WorkflowNodeExecution, \
+    WorkflowNodeExecutionTriggeredFrom
+
+
+class WorkflowRunService:
+    def get_paginate_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination:
+        """
+        Get debug workflow run list
+        Only return triggered_from == debugging
+
+        :param app_model: app model
+        :param args: request args
+        """
+        limit = int(args.get('limit', 20))
+
+        base_query = db.session.query(WorkflowRun).filter(
+            WorkflowRun.tenant_id == app_model.tenant_id,
+            WorkflowRun.app_id == app_model.id,
+            WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value
+        )
+
+        if args.get('last_id'):
+            last_workflow_run = base_query.filter(
+                WorkflowRun.id == args.get('last_id'),
+            ).first()
+
+            if not last_workflow_run:
+                raise ValueError('Last workflow run does not exist')
+
+            workflow_runs = base_query.filter(
+                WorkflowRun.created_at < last_workflow_run.created_at,
+                WorkflowRun.id != last_workflow_run.id
+            ).order_by(WorkflowRun.created_at.desc()).limit(limit).all()
+        else:
+            workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()
+
+        has_more = False
+        if len(workflow_runs) == limit:
+            current_page_first_workflow_run = workflow_runs[-1]
+            rest_count = base_query.filter(
+                WorkflowRun.created_at < current_page_first_workflow_run.created_at,
+                WorkflowRun.id != current_page_first_workflow_run.id
+            ).count()
+
+            if rest_count > 0:
+                has_more = True
+
+        return InfiniteScrollPagination(
+            data=workflow_runs,
+            limit=limit,
+            has_more=has_more
+        )
+
+    def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun:
+        """
+        Get workflow run detail
+
+        
:param app_model: app model + :param run_id: workflow run id + """ + workflow_run = db.session.query(WorkflowRun).filter( + WorkflowRun.tenant_id == app_model.tenant_id, + WorkflowRun.app_id == app_model.id, + WorkflowRun.id == run_id, + ).first() + + return workflow_run + + def get_workflow_run_node_executions(self, app_model: App, run_id: str) -> list[WorkflowNodeExecution]: + """ + Get workflow run node execution list + """ + workflow_run = self.get_workflow_run(app_model, run_id) + + if not workflow_run: + return [] + + node_executions = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.tenant_id == app_model.tenant_id, + WorkflowNodeExecution.app_id == app_model.id, + WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.workflow_run_id == run_id, + ).order_by(WorkflowNodeExecution.index.desc()).all() + + return node_executions From 124aa9db08f90a3fb8900dfed35ce1f018678520 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 27 Feb 2024 21:39:20 +0800 Subject: [PATCH 075/200] lint fix --- api/controllers/console/app/workflow_run.py | 7 +++++-- api/services/workflow_run_service.py | 8 ++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 38e3d4d837..8a4c0492a1 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -5,8 +5,11 @@ from controllers.console import api from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from fields.workflow_run_fields import workflow_run_detail_fields, workflow_run_pagination_fields, \ - workflow_run_node_execution_list_fields +from fields.workflow_run_fields import ( + workflow_run_detail_fields, + workflow_run_node_execution_list_fields, + workflow_run_pagination_fields, +) from libs.helper import uuid_value from libs.login import login_required from models.model import App, AppMode diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 9c898f10fb..70ce1f2ce0 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -1,8 +1,12 @@ from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.model import App -from models.workflow import WorkflowRun, WorkflowRunTriggeredFrom, WorkflowNodeExecution, \ - WorkflowNodeExecutionTriggeredFrom +from models.workflow import ( + WorkflowNodeExecution, + WorkflowNodeExecutionTriggeredFrom, + WorkflowRun, + WorkflowRunTriggeredFrom, +) class WorkflowRunService: From 7724d010b6e4e025e60e135ec85963928fc146c1 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 28 Feb 2024 16:27:41 +0800 Subject: [PATCH 076/200] add app description add update app api --- api/controllers/console/app/app.py | 23 ++++++++++++- api/fields/app_fields.py | 4 +++ .../f9107f83abab_add_desc_for_apps.py | 32 +++++++++++++++++++ api/models/model.py | 4 ++- api/models/workflow.py | 4 ++- api/services/app_service.py | 20 +++++++++++- 6 files changed, 83 insertions(+), 4 deletions(-) create mode 100644 api/migrations/versions/f9107f83abab_add_desc_for_apps.py diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 898fd4f7c4..98636fa95f 100644 --- 
a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,5 +1,5 @@ from flask_login import current_user -from flask_restful import Resource, abort, inputs, marshal_with, reqparse +from flask_restful import Resource, inputs, marshal_with, reqparse from werkzeug.exceptions import Forbidden, BadRequest from controllers.console import api @@ -53,6 +53,7 @@ class AppListApi(Resource): """Create app""" parser = reqparse.RequestParser() parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument('description', type=str, location='json') parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json') parser.add_argument('icon', type=str, location='json') parser.add_argument('icon_background', type=str, location='json') @@ -86,6 +87,7 @@ class AppImportApi(Resource): parser = reqparse.RequestParser() parser.add_argument('data', type=str, required=True, nullable=False, location='json') parser.add_argument('name', type=str, location='json') + parser.add_argument('description', type=str, location='json') parser.add_argument('icon', type=str, location='json') parser.add_argument('icon_background', type=str, location='json') args = parser.parse_args() @@ -144,6 +146,25 @@ class AppApi(Resource): return app_model + @setup_required + @login_required + @account_initialization_required + @get_app_model + @marshal_with(app_detail_fields_with_site) + def put(self, app_model): + """Update app""" + parser = reqparse.RequestParser() + parser.add_argument('name', type=str, required=True, nullable=False, location='json') + parser.add_argument('description', type=str, location='json') + parser.add_argument('icon', type=str, location='json') + parser.add_argument('icon_background', type=str, location='json') + args = parser.parse_args() + + app_service = AppService() + app_model = app_service.update_app(app_model, args) + + return app_model + @setup_required @login_required @account_initialization_required diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 75b68d24fc..69ab1d3e3e 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -5,6 +5,7 @@ from libs.helper import TimestampField app_detail_kernel_fields = { 'id': fields.String, 'name': fields.String, + 'description': fields.String, 'mode': fields.String, 'icon': fields.String, 'icon_background': fields.String, @@ -41,6 +42,7 @@ model_config_fields = { app_detail_fields = { 'id': fields.String, 'name': fields.String, + 'description': fields.String, 'mode': fields.String, 'icon': fields.String, 'icon_background': fields.String, @@ -62,6 +64,7 @@ model_config_partial_fields = { app_partial_fields = { 'id': fields.String, 'name': fields.String, + 'description': fields.String, 'mode': fields.String, 'icon': fields.String, 'icon_background': fields.String, @@ -109,6 +112,7 @@ site_fields = { app_detail_fields_with_site = { 'id': fields.String, 'name': fields.String, + 'description': fields.String, 'mode': fields.String, 'icon': fields.String, 'icon_background': fields.String, diff --git a/api/migrations/versions/f9107f83abab_add_desc_for_apps.py b/api/migrations/versions/f9107f83abab_add_desc_for_apps.py new file mode 100644 index 0000000000..88d77bb320 --- /dev/null +++ b/api/migrations/versions/f9107f83abab_add_desc_for_apps.py @@ -0,0 +1,32 @@ +"""add desc for apps + +Revision ID: f9107f83abab +Revises: cc04d0998d4d +Create Date: 2024-02-28 08:16:14.090481 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by 
Alembic.
+revision = 'f9107f83abab'
+down_revision = 'cc04d0998d4d'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('apps', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('description', sa.Text(), server_default=sa.text("''::character varying"), nullable=False))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('apps', schema=None) as batch_op:
+        batch_op.drop_column('description')
+
+    # ### end Alembic commands ###
diff --git a/api/models/model.py b/api/models/model.py
index 713d8da577..8d286d3482 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -14,7 +14,6 @@ from extensions.ext_database import db
 from libs.helper import generate_string
 
 from .account import Account, Tenant
-from .workflow import Workflow, WorkflowRun
 
 
 class DifySetup(db.Model):
@@ -59,6 +58,7 @@ class App(db.Model):
     id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
     tenant_id = db.Column(UUID, nullable=False)
     name = db.Column(db.String(255), nullable=False)
+    description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying"))
     mode = db.Column(db.String(255), nullable=False)
     icon = db.Column(db.String(255))
     icon_background = db.Column(db.String(255))
@@ -279,6 +279,7 @@ class AppModelConfig(db.Model):
     @property
     def workflow(self):
         if self.workflow_id:
+            from models.workflow import Workflow
             return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first()
 
         return None
@@ -692,6 +693,7 @@ class Message(db.Model):
     @property
     def workflow_run(self):
         if self.workflow_run_id:
+            from models.workflow import WorkflowRun
             return db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first()
 
         return None
diff --git a/api/models/workflow.py b/api/models/workflow.py
index 7ea342cda7..316d3e623e 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -5,7 +5,6 @@ from sqlalchemy.dialects.postgresql import UUID
 from extensions.ext_database import db
 from models.account import Account
-from models.model import EndUser
 
 
 class CreatedByRole(Enum):
@@ -242,6 +241,7 @@ class WorkflowRun(db.Model):
 
     @property
     def created_by_end_user(self):
+        from models.model import EndUser
         created_by_role = CreatedByRole.value_of(self.created_by_role)
         return EndUser.query.get(self.created_by) \
             if created_by_role == CreatedByRole.END_USER else None
@@ -356,6 +356,7 @@ class WorkflowNodeExecution(db.Model):
 
     @property
     def created_by_end_user(self):
+        from models.model import EndUser
         created_by_role = CreatedByRole.value_of(self.created_by_role)
         return EndUser.query.get(self.created_by) \
             if created_by_role == CreatedByRole.END_USER else None
@@ -418,6 +419,7 @@ class WorkflowAppLog(db.Model):
 
     @property
     def created_by_end_user(self):
+        from models.model import EndUser
         created_by_role = CreatedByRole.value_of(self.created_by_role)
         return EndUser.query.get(self.created_by) \
             if created_by_role == CreatedByRole.END_USER else None
diff --git a/api/services/app_service.py b/api/services/app_service.py
index 5de87dbad5..2e534eae15 100644
--- a/api/services/app_service.py
+++ b/api/services/app_service.py
@@ -97,10 +97,11 @@ class AppService:
 
         app = App(**app_template['app'])
         app.name = args['name']
+        app.description = args.get('description', '')
         app.mode = args['mode']
         app.icon = args['icon']
         app.icon_background = args['icon_background']
-        
app.tenant_id = account.current_tenant_id + app.tenant_id = tenant_id db.session.add(app) db.session.flush() @@ -145,6 +146,7 @@ class AppService: tenant_id=tenant_id, mode=app_data.get('mode'), name=args.get("name") if args.get("name") else app_data.get('name'), + description=args.get("description") if args.get("description") else app_data.get('description', ''), icon=args.get("icon") if args.get("icon") else app_data.get('icon'), icon_background=args.get("icon_background") if args.get("icon_background") \ else app_data.get('icon_background'), @@ -205,6 +207,22 @@ class AppService: return yaml.dump(export_data) + def update_app(self, app: App, args: dict) -> App: + """ + Update app + :param app: App instance + :param args: request args + :return: App instance + """ + app.name = args.get('name') + app.description = args.get('description', '') + app.icon = args.get('icon') + app.icon_background = args.get('icon_background') + app.updated_at = datetime.utcnow() + db.session.commit() + + return app + def update_app_name(self, app: App, name: str) -> App: """ Update app name From 11337e51c54ce2574dbde767337450567804e18d Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 28 Feb 2024 16:27:49 +0800 Subject: [PATCH 077/200] lint fix --- api/migrations/versions/f9107f83abab_add_desc_for_apps.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/api/migrations/versions/f9107f83abab_add_desc_for_apps.py b/api/migrations/versions/f9107f83abab_add_desc_for_apps.py index 88d77bb320..3e5ae0d67d 100644 --- a/api/migrations/versions/f9107f83abab_add_desc_for_apps.py +++ b/api/migrations/versions/f9107f83abab_add_desc_for_apps.py @@ -5,9 +5,8 @@ Revises: cc04d0998d4d Create Date: 2024-02-28 08:16:14.090481 """ -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. 
revision = 'f9107f83abab' From 022b7d5dd442621cbb7044df2b7fee6ad2c4bbbe Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 28 Feb 2024 18:24:49 +0800 Subject: [PATCH 078/200] optimize default model exceptions --- api/services/app_service.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/api/services/app_service.py b/api/services/app_service.py index 2e534eae15..298cd650df 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -1,4 +1,5 @@ import json +import logging from datetime import datetime from typing import cast @@ -6,7 +7,7 @@ import yaml from flask_sqlalchemy.pagination import Pagination from constants.model_template import default_app_templates -from core.errors.error import ProviderTokenNotInitError +from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -74,7 +75,10 @@ class AppService: tenant_id=account.current_tenant_id, model_type=ModelType.LLM ) - except ProviderTokenNotInitError: + except (ProviderTokenNotInitError, LLMBadRequestError): + model_instance = None + except Exception as e: + logging.exception(e) model_instance = None if model_instance: From dd70aeff247be188c834e8af06efab3c0c0e61be Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 28 Feb 2024 18:27:16 +0800 Subject: [PATCH 079/200] lint fix --- api/services/app_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/services/app_service.py b/api/services/app_service.py index 298cd650df..374727d2d4 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -7,7 +7,7 @@ import yaml from flask_sqlalchemy.pagination import Pagination from constants.model_template import default_app_templates -from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError +from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel From 77618823a5c1da589f9d32732d3b8ef0b7907b83 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 28 Feb 2024 22:16:24 +0800 Subject: [PATCH 080/200] add features update api refactor app model config validation --- api/controllers/console/app/model_config.py | 43 +- api/core/apps/__init__.py | 0 .../apps/app_config_validators/__init__.py | 0 .../advanced_chat_app.py | 54 ++ .../app_config_validators/agent_chat_app.py | 82 +++ .../apps/app_config_validators/chat_app.py | 82 +++ .../app_config_validators/completion_app.py | 67 +++ .../app_config_validators/workflow_app.py | 34 ++ api/core/apps/config_validators/__init__.py | 0 api/core/apps/config_validators/agent.py | 82 +++ api/core/apps/config_validators/dataset.py | 141 +++++ .../config_validators/external_data_tools.py | 40 ++ .../apps/config_validators/file_upload.py | 38 ++ api/core/apps/config_validators/model.py | 83 +++ api/core/apps/config_validators/moderation.py | 36 ++ .../apps/config_validators/more_like_this.py | 26 + .../config_validators/opening_statement.py | 29 + api/core/apps/config_validators/prompt.py | 87 +++ .../config_validators/retriever_resource.py | 26 + .../apps/config_validators/speech_to_text.py | 26 + .../config_validators/suggested_questions.py | 
26 +
 .../apps/config_validators/text_to_speech.py  |  30 +
 .../apps/config_validators/user_input_form.py |  62 ++
 api/services/app_model_config_service.py      | 539 +-----------------
 api/services/completion_service.py            |  11 +-
 api/services/workflow_service.py              |   2 +-
 26 files changed, 1115 insertions(+), 531 deletions(-)
 create mode 100644 api/core/apps/__init__.py
 create mode 100644 api/core/apps/app_config_validators/__init__.py
 create mode 100644 api/core/apps/app_config_validators/advanced_chat_app.py
 create mode 100644 api/core/apps/app_config_validators/agent_chat_app.py
 create mode 100644 api/core/apps/app_config_validators/chat_app.py
 create mode 100644 api/core/apps/app_config_validators/completion_app.py
 create mode 100644 api/core/apps/app_config_validators/workflow_app.py
 create mode 100644 api/core/apps/config_validators/__init__.py
 create mode 100644 api/core/apps/config_validators/agent.py
 create mode 100644 api/core/apps/config_validators/dataset.py
 create mode 100644 api/core/apps/config_validators/external_data_tools.py
 create mode 100644 api/core/apps/config_validators/file_upload.py
 create mode 100644 api/core/apps/config_validators/model.py
 create mode 100644 api/core/apps/config_validators/moderation.py
 create mode 100644 api/core/apps/config_validators/more_like_this.py
 create mode 100644 api/core/apps/config_validators/opening_statement.py
 create mode 100644 api/core/apps/config_validators/prompt.py
 create mode 100644 api/core/apps/config_validators/retriever_resource.py
 create mode 100644 api/core/apps/config_validators/speech_to_text.py
 create mode 100644 api/core/apps/config_validators/suggested_questions.py
 create mode 100644 api/core/apps/config_validators/text_to_speech.py
 create mode 100644 api/core/apps/config_validators/user_input_form.py

diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py
index 0f8bc28f6f..0ae9f5e546 100644
--- a/api/controllers/console/app/model_config.py
+++ b/api/controllers/console/app/model_config.py
@@ -2,7 +2,7 @@ import json
 
 from flask import request
 from flask_login import current_user
-from flask_restful import Resource
+from flask_restful import Resource, reqparse
 
 from controllers.console import api
 from controllers.console.app.wraps import get_app_model
@@ -14,7 +14,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
 from events.app_event import app_model_config_was_updated
 from extensions.ext_database import db
 from libs.login import login_required
-from models.model import AppModelConfig
+from models.model import AppModelConfig, AppMode
 from services.app_model_config_service import AppModelConfigService
 
 
@@ -23,15 +23,14 @@ class ModelConfigResource(Resource):
 
     @setup_required
     @login_required
     @account_initialization_required
-    @get_app_model
+    @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
     def post(self, app_model):
         """Modify app model config"""
         # validate config
         model_configuration = AppModelConfigService.validate_configuration(
             tenant_id=current_user.current_tenant_id,
-            account=current_user,
             config=request.json,
-            app_mode=app_model.mode
+            app_mode=AppMode.value_of(app_model.mode)
         )
 
         new_app_model_config = AppModelConfig(
@@ -138,4 +137,38 @@ class ModelConfigResource(Resource):
         return {'result': 'success'}
 
 
+class FeaturesResource(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    def put(self, app_model):
+        """Update app 
features""" + parser = reqparse.RequestParser() + parser.add_argument('features', type=dict, required=True, nullable=False, location='json') + args = parser.parse_args() + + model_configuration = AppModelConfigService.validate_features( + tenant_id=current_user.current_tenant_id, + config=args.get('features'), + app_mode=AppMode.value_of(app_model.mode) + ) + + # update config + app_model_config = app_model.app_model_config + app_model_config.from_model_config_dict(model_configuration) + db.session.commit() + + app_model_config_was_updated.send( + app_model, + app_model_config=app_model_config + ) + + return { + 'result': 'success' + } + + api.add_resource(ModelConfigResource, '/apps//model-config') +api.add_resource(FeaturesResource, '/apps//features') diff --git a/api/core/apps/__init__.py b/api/core/apps/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/apps/app_config_validators/__init__.py b/api/core/apps/app_config_validators/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/apps/app_config_validators/advanced_chat_app.py b/api/core/apps/app_config_validators/advanced_chat_app.py new file mode 100644 index 0000000000..dc7664b844 --- /dev/null +++ b/api/core/apps/app_config_validators/advanced_chat_app.py @@ -0,0 +1,54 @@ +from core.apps.config_validators.file_upload import FileUploadValidator +from core.apps.config_validators.moderation import ModerationValidator +from core.apps.config_validators.opening_statement import OpeningStatementValidator +from core.apps.config_validators.retriever_resource import RetrieverResourceValidator +from core.apps.config_validators.speech_to_text import SpeechToTextValidator +from core.apps.config_validators.suggested_questions import SuggestedQuestionsValidator +from core.apps.config_validators.text_to_speech import TextToSpeechValidator + + +class AdvancedChatAppConfigValidator: + @classmethod + def config_validate(cls, tenant_id: str, config: dict) -> dict: + """ + Validate for advanced chat app model config + + :param tenant_id: tenant id + :param config: app model config args + """ + related_config_keys = [] + + # file upload validation + config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # opening_statement + config, current_related_config_keys = OpeningStatementValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # suggested_questions_after_answer + config, current_related_config_keys = SuggestedQuestionsValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # speech_to_text + config, current_related_config_keys = SpeechToTextValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # return retriever resource + config, current_related_config_keys = RetrieverResourceValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + 
filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config diff --git a/api/core/apps/app_config_validators/agent_chat_app.py b/api/core/apps/app_config_validators/agent_chat_app.py new file mode 100644 index 0000000000..d507fae685 --- /dev/null +++ b/api/core/apps/app_config_validators/agent_chat_app.py @@ -0,0 +1,82 @@ +from core.apps.config_validators.agent import AgentValidator +from core.apps.config_validators.external_data_tools import ExternalDataToolsValidator +from core.apps.config_validators.file_upload import FileUploadValidator +from core.apps.config_validators.model import ModelValidator +from core.apps.config_validators.moderation import ModerationValidator +from core.apps.config_validators.opening_statement import OpeningStatementValidator +from core.apps.config_validators.prompt import PromptValidator +from core.apps.config_validators.retriever_resource import RetrieverResourceValidator +from core.apps.config_validators.speech_to_text import SpeechToTextValidator +from core.apps.config_validators.suggested_questions import SuggestedQuestionsValidator +from core.apps.config_validators.text_to_speech import TextToSpeechValidator +from core.apps.config_validators.user_input_form import UserInputFormValidator +from models.model import AppMode + + +class AgentChatAppConfigValidator: + @classmethod + def config_validate(cls, tenant_id: str, config: dict) -> dict: + """ + Validate for agent chat app model config + + :param tenant_id: tenant id + :param config: app model config args + """ + app_mode = AppMode.AGENT_CHAT + + related_config_keys = [] + + # model + config, current_related_config_keys = ModelValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # user_input_form + config, current_related_config_keys = UserInputFormValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # external data tools validation + config, current_related_config_keys = ExternalDataToolsValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # file upload validation + config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # prompt + config, current_related_config_keys = PromptValidator.validate_and_set_defaults(app_mode, config) + related_config_keys.extend(current_related_config_keys) + + # agent_mode + config, current_related_config_keys = AgentValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # opening_statement + config, current_related_config_keys = OpeningStatementValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # suggested_questions_after_answer + config, current_related_config_keys = SuggestedQuestionsValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # speech_to_text + config, current_related_config_keys = SpeechToTextValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # return retriever resource + config, current_related_config_keys = 
RetrieverResourceValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config diff --git a/api/core/apps/app_config_validators/chat_app.py b/api/core/apps/app_config_validators/chat_app.py new file mode 100644 index 0000000000..83c792e610 --- /dev/null +++ b/api/core/apps/app_config_validators/chat_app.py @@ -0,0 +1,82 @@ +from core.apps.config_validators.dataset import DatasetValidator +from core.apps.config_validators.external_data_tools import ExternalDataToolsValidator +from core.apps.config_validators.file_upload import FileUploadValidator +from core.apps.config_validators.model import ModelValidator +from core.apps.config_validators.moderation import ModerationValidator +from core.apps.config_validators.opening_statement import OpeningStatementValidator +from core.apps.config_validators.prompt import PromptValidator +from core.apps.config_validators.retriever_resource import RetrieverResourceValidator +from core.apps.config_validators.speech_to_text import SpeechToTextValidator +from core.apps.config_validators.suggested_questions import SuggestedQuestionsValidator +from core.apps.config_validators.text_to_speech import TextToSpeechValidator +from core.apps.config_validators.user_input_form import UserInputFormValidator +from models.model import AppMode + + +class ChatAppConfigValidator: + @classmethod + def config_validate(cls, tenant_id: str, config: dict) -> dict: + """ + Validate for chat app model config + + :param tenant_id: tenant id + :param config: app model config args + """ + app_mode = AppMode.CHAT + + related_config_keys = [] + + # model + config, current_related_config_keys = ModelValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # user_input_form + config, current_related_config_keys = UserInputFormValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # external data tools validation + config, current_related_config_keys = ExternalDataToolsValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # file upload validation + config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # prompt + config, current_related_config_keys = PromptValidator.validate_and_set_defaults(app_mode, config) + related_config_keys.extend(current_related_config_keys) + + # dataset_query_variable + config, current_related_config_keys = DatasetValidator.validate_and_set_defaults(tenant_id, app_mode, config) + related_config_keys.extend(current_related_config_keys) + + # opening_statement + config, current_related_config_keys = OpeningStatementValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # suggested_questions_after_answer + config, current_related_config_keys = SuggestedQuestionsValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # speech_to_text + config, current_related_config_keys = 
SpeechToTextValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # return retriever resource + config, current_related_config_keys = RetrieverResourceValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config diff --git a/api/core/apps/app_config_validators/completion_app.py b/api/core/apps/app_config_validators/completion_app.py new file mode 100644 index 0000000000..00371f8d05 --- /dev/null +++ b/api/core/apps/app_config_validators/completion_app.py @@ -0,0 +1,67 @@ +from core.apps.config_validators.dataset import DatasetValidator +from core.apps.config_validators.external_data_tools import ExternalDataToolsValidator +from core.apps.config_validators.file_upload import FileUploadValidator +from core.apps.config_validators.model import ModelValidator +from core.apps.config_validators.moderation import ModerationValidator +from core.apps.config_validators.more_like_this import MoreLikeThisValidator +from core.apps.config_validators.prompt import PromptValidator +from core.apps.config_validators.text_to_speech import TextToSpeechValidator +from core.apps.config_validators.user_input_form import UserInputFormValidator +from models.model import AppMode + + +class CompletionAppConfigValidator: + @classmethod + def config_validate(cls, tenant_id: str, config: dict) -> dict: + """ + Validate for completion app model config + + :param tenant_id: tenant id + :param config: app model config args + """ + app_mode = AppMode.COMPLETION + + related_config_keys = [] + + # model + config, current_related_config_keys = ModelValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # user_input_form + config, current_related_config_keys = UserInputFormValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # external data tools validation + config, current_related_config_keys = ExternalDataToolsValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # file upload validation + config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # prompt + config, current_related_config_keys = PromptValidator.validate_and_set_defaults(app_mode, config) + related_config_keys.extend(current_related_config_keys) + + # dataset_query_variable + config, current_related_config_keys = DatasetValidator.validate_and_set_defaults(tenant_id, app_mode, config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # more_like_this + config, current_related_config_keys = MoreLikeThisValidator.validate_and_set_defaults(config) + 
related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config diff --git a/api/core/apps/app_config_validators/workflow_app.py b/api/core/apps/app_config_validators/workflow_app.py new file mode 100644 index 0000000000..545d3d79a3 --- /dev/null +++ b/api/core/apps/app_config_validators/workflow_app.py @@ -0,0 +1,34 @@ +from core.apps.config_validators.file_upload import FileUploadValidator +from core.apps.config_validators.moderation import ModerationValidator +from core.apps.config_validators.text_to_speech import TextToSpeechValidator + + +class WorkflowAppConfigValidator: + @classmethod + def config_validate(cls, tenant_id: str, config: dict) -> dict: + """ + Validate for workflow app model config + + :param tenant_id: tenant id + :param config: app model config args + """ + related_config_keys = [] + + # file upload validation + config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config diff --git a/api/core/apps/config_validators/__init__.py b/api/core/apps/config_validators/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/apps/config_validators/agent.py b/api/core/apps/config_validators/agent.py new file mode 100644 index 0000000000..69f9338080 --- /dev/null +++ b/api/core/apps/config_validators/agent.py @@ -0,0 +1,82 @@ +import uuid +from typing import Tuple + +from core.agent.agent_executor import PlanningStrategy +from core.apps.config_validators.dataset import DatasetValidator + +OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] + + +class AgentValidator: + @classmethod + def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> Tuple[dict, list[str]]: + """ + Validate and set defaults for agent feature + + :param tenant_id: tenant ID + :param config: app model config args + """ + if not config.get("agent_mode"): + config["agent_mode"] = { + "enabled": False, + "tools": [] + } + + if not isinstance(config["agent_mode"], dict): + raise ValueError("agent_mode must be of object type") + + if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]: + config["agent_mode"]["enabled"] = False + + if not isinstance(config["agent_mode"]["enabled"], bool): + raise ValueError("enabled in agent_mode must be of boolean type") + + if not config["agent_mode"].get("strategy"): + config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value + + if config["agent_mode"]["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]: + raise 
ValueError("strategy in agent_mode must be in the specified strategy list") + + if not config["agent_mode"].get("tools"): + config["agent_mode"]["tools"] = [] + + if not isinstance(config["agent_mode"]["tools"], list): + raise ValueError("tools in agent_mode must be a list of objects") + + for tool in config["agent_mode"]["tools"]: + key = list(tool.keys())[0] + if key in OLD_TOOLS: + # old style, use tool name as key + tool_item = tool[key] + + if "enabled" not in tool_item or not tool_item["enabled"]: + tool_item["enabled"] = False + + if not isinstance(tool_item["enabled"], bool): + raise ValueError("enabled in agent_mode.tools must be of boolean type") + + if key == "dataset": + if 'id' not in tool_item: + raise ValueError("id is required in dataset") + + try: + uuid.UUID(tool_item["id"]) + except ValueError: + raise ValueError("id in dataset must be of UUID type") + + if not DatasetValidator.is_dataset_exists(tenant_id, tool_item["id"]): + raise ValueError("Dataset ID does not exist, please check your permission.") + else: + # latest style, use key-value pair + if "enabled" not in tool or not tool["enabled"]: + tool["enabled"] = False + if "provider_type" not in tool: + raise ValueError("provider_type is required in agent_mode.tools") + if "provider_id" not in tool: + raise ValueError("provider_id is required in agent_mode.tools") + if "tool_name" not in tool: + raise ValueError("tool_name is required in agent_mode.tools") + if "tool_parameters" not in tool: + raise ValueError("tool_parameters is required in agent_mode.tools") + + return config, ["agent_mode"] diff --git a/api/core/apps/config_validators/dataset.py b/api/core/apps/config_validators/dataset.py new file mode 100644 index 0000000000..32db038c21 --- /dev/null +++ b/api/core/apps/config_validators/dataset.py @@ -0,0 +1,141 @@ +import uuid +from typing import Tuple + +from core.agent.agent_executor import PlanningStrategy +from models.model import AppMode +from services.dataset_service import DatasetService + + +class DatasetValidator: + @classmethod + def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> Tuple[dict, list[str]]: + """ + Validate and set defaults for dataset feature + + :param tenant_id: tenant ID + :param app_mode: app mode + :param config: app model config args + """ + # Extract dataset config for legacy compatibility + config = cls.extract_dataset_config_for_legacy_compatibility(tenant_id, app_mode, config) + + # dataset_configs + if not config.get("dataset_configs"): + config["dataset_configs"] = {'retrieval_model': 'single'} + + if not config["dataset_configs"].get("datasets"): + config["dataset_configs"]["datasets"] = { + "strategy": "router", + "datasets": [] + } + + if not isinstance(config["dataset_configs"], dict): + raise ValueError("dataset_configs must be of object type") + + if config["dataset_configs"]['retrieval_model'] == 'multiple': + if not config["dataset_configs"]['reranking_model']: + raise ValueError("reranking_model has not been set") + if not isinstance(config["dataset_configs"]['reranking_model'], dict): + raise ValueError("reranking_model must be of object type") + + if not isinstance(config["dataset_configs"], dict): + raise ValueError("dataset_configs must be of object type") + + need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get("datasets") + + if need_manual_query_datasets and app_mode == AppMode.COMPLETION: + # Only check when mode is completion + dataset_query_variable = 
config.get("dataset_query_variable") + + if not dataset_query_variable: + raise ValueError("Dataset query variable is required when dataset is exist") + + return config, ["agent_mode", "dataset_configs", "dataset_query_variable"] + + @classmethod + def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict) -> dict: + """ + Extract dataset config for legacy compatibility + + :param tenant_id: tenant ID + :param app_mode: app mode + :param config: app model config args + """ + # Extract dataset config for legacy compatibility + if not config.get("agent_mode"): + config["agent_mode"] = { + "enabled": False, + "tools": [] + } + + if not isinstance(config["agent_mode"], dict): + raise ValueError("agent_mode must be of object type") + + # enabled + if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]: + config["agent_mode"]["enabled"] = False + + if not isinstance(config["agent_mode"]["enabled"], bool): + raise ValueError("enabled in agent_mode must be of boolean type") + + # tools + if not config["agent_mode"].get("tools"): + config["agent_mode"]["tools"] = [] + + if not isinstance(config["agent_mode"]["tools"], list): + raise ValueError("tools in agent_mode must be a list of objects") + + # strategy + if not config["agent_mode"].get("strategy"): + config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value + + has_datasets = False + if config["agent_mode"]["strategy"] in [PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value]: + for tool in config["agent_mode"]["tools"]: + key = list(tool.keys())[0] + if key == "dataset": + # old style, use tool name as key + tool_item = tool[key] + + if "enabled" not in tool_item or not tool_item["enabled"]: + tool_item["enabled"] = False + + if not isinstance(tool_item["enabled"], bool): + raise ValueError("enabled in agent_mode.tools must be of boolean type") + + if 'id' not in tool_item: + raise ValueError("id is required in dataset") + + try: + uuid.UUID(tool_item["id"]) + except ValueError: + raise ValueError("id in dataset must be of UUID type") + + if not cls.is_dataset_exists(tenant_id, tool_item["id"]): + raise ValueError("Dataset ID does not exist, please check your permission.") + + has_datasets = True + + need_manual_query_datasets = has_datasets and config["agent_mode"]["enabled"] + + if need_manual_query_datasets and app_mode == AppMode.COMPLETION: + # Only check when mode is completion + dataset_query_variable = config.get("dataset_query_variable") + + if not dataset_query_variable: + raise ValueError("Dataset query variable is required when dataset is exist") + + return config + + @classmethod + def is_dataset_exists(cls, tenant_id: str, dataset_id: str) -> bool: + # verify if the dataset ID exists + dataset = DatasetService.get_dataset(dataset_id) + + if not dataset: + return False + + if dataset.tenant_id != tenant_id: + return False + + return True diff --git a/api/core/apps/config_validators/external_data_tools.py b/api/core/apps/config_validators/external_data_tools.py new file mode 100644 index 0000000000..5412366a89 --- /dev/null +++ b/api/core/apps/config_validators/external_data_tools.py @@ -0,0 +1,40 @@ +from typing import Tuple + +from core.external_data_tool.factory import ExternalDataToolFactory + + +class ExternalDataToolsValidator: + @classmethod + def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> Tuple[dict, list[str]]: + """ + Validate and set defaults for external data fetch feature + + :param tenant_id: workspace id 
+        :param config: app model config args
+        """
+        if not config.get("external_data_tools"):
+            config["external_data_tools"] = []
+
+        if not isinstance(config["external_data_tools"], list):
+            raise ValueError("external_data_tools must be of list type")
+
+        for tool in config["external_data_tools"]:
+            if "enabled" not in tool or not tool["enabled"]:
+                tool["enabled"] = False
+
+            if not tool["enabled"]:
+                continue
+
+            if "type" not in tool or not tool["type"]:
+                raise ValueError("external_data_tools[].type is required")
+
+            typ = tool["type"]
+            tool_config = tool["config"]
+
+            ExternalDataToolFactory.validate_config(
+                name=typ,
+                tenant_id=tenant_id,
+                config=tool_config
+            )
+
+        return config, ["external_data_tools"]
diff --git a/api/core/apps/config_validators/file_upload.py b/api/core/apps/config_validators/file_upload.py
new file mode 100644
index 0000000000..f9adbfdf7d
--- /dev/null
+++ b/api/core/apps/config_validators/file_upload.py
@@ -0,0 +1,38 @@
+from typing import Tuple
+
+
+class FileUploadValidator:
+    @classmethod
+    def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]:
+        """
+        Validate and set defaults for file upload feature
+
+        :param config: app model config args
+        """
+        if not config.get("file_upload"):
+            config["file_upload"] = {}
+
+        if not isinstance(config["file_upload"], dict):
+            raise ValueError("file_upload must be of dict type")
+
+        # check image config
+        if not config["file_upload"].get("image"):
+            config["file_upload"]["image"] = {"enabled": False}
+
+        if config['file_upload']['image']['enabled']:
+            number_limits = config['file_upload']['image']['number_limits']
+            if number_limits < 1 or number_limits > 6:
+                raise ValueError("number_limits must be in [1, 6]")
+
+            detail = config['file_upload']['image']['detail']
+            if detail not in ['high', 'low']:
+                raise ValueError("detail must be in ['high', 'low']")
+
+            transfer_methods = config['file_upload']['image']['transfer_methods']
+            if not isinstance(transfer_methods, list):
+                raise ValueError("transfer_methods must be of list type")
+            for method in transfer_methods:
+                if method not in ['remote_url', 'local_file']:
+                    raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
+
+        return config, ["file_upload"]
diff --git a/api/core/apps/config_validators/model.py b/api/core/apps/config_validators/model.py
new file mode 100644
index 0000000000..091eec4683
--- /dev/null
+++ b/api/core/apps/config_validators/model.py
@@ -0,0 +1,83 @@
+from typing import Tuple
+
+from core.model_runtime.entities.model_entities import ModelType, ModelPropertyKey
+from core.model_runtime.model_providers import model_provider_factory
+from core.provider_manager import ProviderManager
+
+
+class ModelValidator:
+    @classmethod
+    def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> Tuple[dict, list[str]]:
+        """
+        Validate and set defaults for model config
+
+        :param tenant_id: tenant id
+        :param config: app model config args
+        """
+        if 'model' not in config:
+            raise ValueError("model is required")
+
+        if not isinstance(config["model"], dict):
+            raise ValueError("model must be of object type")
+
+        # model.provider
+        provider_entities = model_provider_factory.get_providers()
+        model_provider_names = [provider.provider for provider in provider_entities]
+        if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names:
+            raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
+
+        # model.name
+        if 'name' not in config["model"]:
+            raise 
ValueError("model.name is required") + + provider_manager = ProviderManager() + models = provider_manager.get_configurations(tenant_id).get_models( + provider=config["model"]["provider"], + model_type=ModelType.LLM + ) + + if not models: + raise ValueError("model.name must be in the specified model list") + + model_ids = [m.model for m in models] + if config["model"]["name"] not in model_ids: + raise ValueError("model.name must be in the specified model list") + + model_mode = None + for model in models: + if model.model == config["model"]["name"]: + model_mode = model.model_properties.get(ModelPropertyKey.MODE) + break + + # model.mode + if model_mode: + config['model']["mode"] = model_mode + else: + config['model']["mode"] = "completion" + + # model.completion_params + if 'completion_params' not in config["model"]: + raise ValueError("model.completion_params is required") + + config["model"]["completion_params"] = cls.validate_model_completion_params( + config["model"]["completion_params"] + ) + + return config, ["model"] + + @classmethod + def validate_model_completion_params(cls, cp: dict) -> dict: + # model.completion_params + if not isinstance(cp, dict): + raise ValueError("model.completion_params must be of object type") + + # stop + if 'stop' not in cp: + cp["stop"] = [] + elif not isinstance(cp["stop"], list): + raise ValueError("stop in model.completion_params must be of list type") + + if len(cp["stop"]) > 4: + raise ValueError("stop sequences must be less than 4") + + return cp diff --git a/api/core/apps/config_validators/moderation.py b/api/core/apps/config_validators/moderation.py new file mode 100644 index 0000000000..1962f87aa9 --- /dev/null +++ b/api/core/apps/config_validators/moderation.py @@ -0,0 +1,36 @@ +import logging +from typing import Tuple + +from core.moderation.factory import ModerationFactory + +logger = logging.getLogger(__name__) + + +class ModerationValidator: + @classmethod + def validate_and_set_defaults(cls, tenant_id, config: dict) -> Tuple[dict, list[str]]: + if not config.get("sensitive_word_avoidance"): + config["sensitive_word_avoidance"] = { + "enabled": False + } + + if not isinstance(config["sensitive_word_avoidance"], dict): + raise ValueError("sensitive_word_avoidance must be of dict type") + + if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]: + config["sensitive_word_avoidance"]["enabled"] = False + + if config["sensitive_word_avoidance"]["enabled"]: + if not config["sensitive_word_avoidance"].get("type"): + raise ValueError("sensitive_word_avoidance.type is required") + + typ = config["sensitive_word_avoidance"]["type"] + config = config["sensitive_word_avoidance"]["config"] + + ModerationFactory.validate_config( + name=typ, + tenant_id=tenant_id, + config=config + ) + + return config, ["sensitive_word_avoidance"] diff --git a/api/core/apps/config_validators/more_like_this.py b/api/core/apps/config_validators/more_like_this.py new file mode 100644 index 0000000000..60dc4a0562 --- /dev/null +++ b/api/core/apps/config_validators/more_like_this.py @@ -0,0 +1,26 @@ +from typing import Tuple + + +class MoreLikeThisValidator: + @classmethod + def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + """ + Validate and set defaults for more like this feature + + :param config: app model config args + """ + if not config.get("more_like_this"): + config["more_like_this"] = { + "enabled": False + } + + if not isinstance(config["more_like_this"], dict): + raise 
ValueError("more_like_this must be of dict type") + + if "enabled" not in config["more_like_this"] or not config["more_like_this"]["enabled"]: + config["more_like_this"]["enabled"] = False + + if not isinstance(config["more_like_this"]["enabled"], bool): + raise ValueError("enabled in more_like_this must be of boolean type") + + return config, ["more_like_this"] diff --git a/api/core/apps/config_validators/opening_statement.py b/api/core/apps/config_validators/opening_statement.py new file mode 100644 index 0000000000..3f69e0e946 --- /dev/null +++ b/api/core/apps/config_validators/opening_statement.py @@ -0,0 +1,29 @@ +from typing import Tuple + + +class OpeningStatementValidator: + @classmethod + def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + """ + Validate and set defaults for opening statement feature + + :param config: app model config args + """ + if not config.get("opening_statement"): + config["opening_statement"] = "" + + if not isinstance(config["opening_statement"], str): + raise ValueError("opening_statement must be of string type") + + # suggested_questions + if not config.get("suggested_questions"): + config["suggested_questions"] = [] + + if not isinstance(config["suggested_questions"], list): + raise ValueError("suggested_questions must be of list type") + + for question in config["suggested_questions"]: + if not isinstance(question, str): + raise ValueError("Elements in suggested_questions list must be of string type") + + return config, ["opening_statement", "suggested_questions"] diff --git a/api/core/apps/config_validators/prompt.py b/api/core/apps/config_validators/prompt.py new file mode 100644 index 0000000000..815706b10b --- /dev/null +++ b/api/core/apps/config_validators/prompt.py @@ -0,0 +1,87 @@ +from typing import Tuple + +from core.entities.application_entities import PromptTemplateEntity +from core.prompt.simple_prompt_transform import ModelMode +from models.model import AppMode + + +class PromptValidator: + @classmethod + def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> Tuple[dict, list[str]]: + """ + Validate pre_prompt and set defaults for prompt feature + depending on the config['model'] + + :param app_mode: app mode + :param config: app model config args + """ + if not config.get("prompt_type"): + config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value + + prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType] + if config['prompt_type'] not in prompt_type_vals: + raise ValueError(f"prompt_type must be in {prompt_type_vals}") + + # chat_prompt_config + if not config.get("chat_prompt_config"): + config["chat_prompt_config"] = {} + + if not isinstance(config["chat_prompt_config"], dict): + raise ValueError("chat_prompt_config must be of object type") + + # completion_prompt_config + if not config.get("completion_prompt_config"): + config["completion_prompt_config"] = {} + + if not isinstance(config["completion_prompt_config"], dict): + raise ValueError("completion_prompt_config must be of object type") + + if config['prompt_type'] == PromptTemplateEntity.PromptType.ADVANCED.value: + if not config['chat_prompt_config'] and not config['completion_prompt_config']: + raise ValueError("chat_prompt_config or completion_prompt_config is required " + "when prompt_type is advanced") + + model_mode_vals = [mode.value for mode in ModelMode] + if config['model']["mode"] not in model_mode_vals: + raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced") + + 
if app_mode == AppMode.CHAT and config['model']["mode"] == ModelMode.COMPLETION.value: + user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix'] + assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] + + if not user_prefix: + config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human' + + if not assistant_prefix: + config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant' + + if config['model']["mode"] == ModelMode.CHAT.value: + prompt_list = config['chat_prompt_config']['prompt'] + + if len(prompt_list) > 10: + raise ValueError("prompt messages must be less than 10") + else: + # pre_prompt, for simple mode + if not config.get("pre_prompt"): + config["pre_prompt"] = "" + + if not isinstance(config["pre_prompt"], str): + raise ValueError("pre_prompt must be of string type") + + return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"] + + @classmethod + def validate_post_prompt_and_set_defaults(cls, config: dict) -> dict: + """ + Validate post_prompt and set defaults for prompt feature + + :param config: app model config args + """ + # post_prompt + if not config.get("post_prompt"): + config["post_prompt"] = "" + + if not isinstance(config["post_prompt"], str): + raise ValueError("post_prompt must be of string type") + + return config \ No newline at end of file diff --git a/api/core/apps/config_validators/retriever_resource.py b/api/core/apps/config_validators/retriever_resource.py new file mode 100644 index 0000000000..a8bcd60abe --- /dev/null +++ b/api/core/apps/config_validators/retriever_resource.py @@ -0,0 +1,26 @@ +from typing import Tuple + + +class RetrieverResourceValidator: + @classmethod + def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + """ + Validate and set defaults for retriever resource feature + + :param config: app model config args + """ + if not config.get("retriever_resource"): + config["retriever_resource"] = { + "enabled": False + } + + if not isinstance(config["retriever_resource"], dict): + raise ValueError("retriever_resource must be of dict type") + + if "enabled" not in config["retriever_resource"] or not config["retriever_resource"]["enabled"]: + config["retriever_resource"]["enabled"] = False + + if not isinstance(config["retriever_resource"]["enabled"], bool): + raise ValueError("enabled in retriever_resource must be of boolean type") + + return config, ["retriever_resource"] diff --git a/api/core/apps/config_validators/speech_to_text.py b/api/core/apps/config_validators/speech_to_text.py new file mode 100644 index 0000000000..577bef0e59 --- /dev/null +++ b/api/core/apps/config_validators/speech_to_text.py @@ -0,0 +1,26 @@ +from typing import Tuple + + +class SpeechToTextValidator: + @classmethod + def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + """ + Validate and set defaults for speech to text feature + + :param config: app model config args + """ + if not config.get("speech_to_text"): + config["speech_to_text"] = { + "enabled": False + } + + if not isinstance(config["speech_to_text"], dict): + raise ValueError("speech_to_text must be of dict type") + + if "enabled" not in config["speech_to_text"] or not config["speech_to_text"]["enabled"]: + config["speech_to_text"]["enabled"] = False + + if not isinstance(config["speech_to_text"]["enabled"], bool): + raise ValueError("enabled in speech_to_text must 
be of boolean type") + + return config, ["speech_to_text"] diff --git a/api/core/apps/config_validators/suggested_questions.py b/api/core/apps/config_validators/suggested_questions.py new file mode 100644 index 0000000000..938b66bb6e --- /dev/null +++ b/api/core/apps/config_validators/suggested_questions.py @@ -0,0 +1,26 @@ +from typing import Tuple + + +class SuggestedQuestionsValidator: + @classmethod + def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + """ + Validate and set defaults for suggested questions feature + + :param config: app model config args + """ + if not config.get("suggested_questions_after_answer"): + config["suggested_questions_after_answer"] = { + "enabled": False + } + + if not isinstance(config["suggested_questions_after_answer"], dict): + raise ValueError("suggested_questions_after_answer must be of dict type") + + if "enabled" not in config["suggested_questions_after_answer"] or not config["suggested_questions_after_answer"]["enabled"]: + config["suggested_questions_after_answer"]["enabled"] = False + + if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool): + raise ValueError("enabled in suggested_questions_after_answer must be of boolean type") + + return config, ["suggested_questions_after_answer"] diff --git a/api/core/apps/config_validators/text_to_speech.py b/api/core/apps/config_validators/text_to_speech.py new file mode 100644 index 0000000000..efe34a8a3e --- /dev/null +++ b/api/core/apps/config_validators/text_to_speech.py @@ -0,0 +1,30 @@ +from typing import Tuple + + +class TextToSpeechValidator: + @classmethod + def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + """ + Validate and set defaults for text to speech feature + + :param config: app model config args + """ + if not config.get("text_to_speech"): + config["text_to_speech"] = { + "enabled": False, + "voice": "", + "language": "" + } + + if not isinstance(config["text_to_speech"], dict): + raise ValueError("text_to_speech must be of dict type") + + if "enabled" not in config["text_to_speech"] or not config["text_to_speech"]["enabled"]: + config["text_to_speech"]["enabled"] = False + config["text_to_speech"]["voice"] = "" + config["text_to_speech"]["language"] = "" + + if not isinstance(config["text_to_speech"]["enabled"], bool): + raise ValueError("enabled in text_to_speech must be of boolean type") + + return config, ["text_to_speech"] diff --git a/api/core/apps/config_validators/user_input_form.py b/api/core/apps/config_validators/user_input_form.py new file mode 100644 index 0000000000..7116c55afc --- /dev/null +++ b/api/core/apps/config_validators/user_input_form.py @@ -0,0 +1,62 @@ +import re +from typing import Tuple + + +class UserInputFormValidator: + @classmethod + def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + """ + Validate and set defaults for user input form + + :param config: app model config args + """ + if not config.get("user_input_form"): + config["user_input_form"] = [] + + if not isinstance(config["user_input_form"], list): + raise ValueError("user_input_form must be a list of objects") + + variables = [] + for item in config["user_input_form"]: + key = list(item.keys())[0] + if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]: + raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'") + + form_item = item[key] + if 'label' not in form_item: + raise ValueError("label is required in 
user_input_form") + + if not isinstance(form_item["label"], str): + raise ValueError("label in user_input_form must be of string type") + + if 'variable' not in form_item: + raise ValueError("variable is required in user_input_form") + + if not isinstance(form_item["variable"], str): + raise ValueError("variable in user_input_form must be of string type") + + pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$") + if pattern.match(form_item["variable"]) is None: + raise ValueError("variable in user_input_form must be a string, " + "and cannot start with a number") + + variables.append(form_item["variable"]) + + if 'required' not in form_item or not form_item["required"]: + form_item["required"] = False + + if not isinstance(form_item["required"], bool): + raise ValueError("required in user_input_form must be of boolean type") + + if key == "select": + if 'options' not in form_item or not form_item["options"]: + form_item["options"] = [] + + if not isinstance(form_item["options"], list): + raise ValueError("options in user_input_form must be a list of strings") + + if "default" in form_item and form_item['default'] \ + and form_item["default"] not in form_item["options"]: + raise ValueError("default value in user_input_form must be in the options list") + + return config, ["user_input_form"] diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 34b6d62d51..c1e0ecebe8 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -1,528 +1,29 @@ -import re -import uuid - -from core.entities.agent_entities import PlanningStrategy -from core.entities.application_entities import AppMode -from core.external_data_tool.factory import ExternalDataToolFactory -from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from core.model_runtime.model_providers import model_provider_factory -from core.moderation.factory import ModerationFactory -from core.provider_manager import ProviderManager -from models.account import Account +from core.apps.app_config_validators.advanced_chat_app import AdvancedChatAppConfigValidator +from core.apps.app_config_validators.agent_chat_app import AgentChatAppConfigValidator +from core.apps.app_config_validators.chat_app import ChatAppConfigValidator +from core.apps.app_config_validators.completion_app import CompletionAppConfigValidator +from core.apps.app_config_validators.workflow_app import WorkflowAppConfigValidator from models.model import AppMode -from services.dataset_service import DatasetService - -SUPPORT_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] class AppModelConfigService: - @classmethod - def is_dataset_exists(cls, account: Account, dataset_id: str) -> bool: - # verify if the dataset ID exists - dataset = DatasetService.get_dataset(dataset_id) - - if not dataset: - return False - - if dataset.tenant_id != account.current_tenant_id: - return False - - return True @classmethod - def validate_model_completion_params(cls, cp: dict, model_name: str) -> dict: - # 6. 
model.completion_params - if not isinstance(cp, dict): - raise ValueError("model.completion_params must be of object type") - - # stop - if 'stop' not in cp: - cp["stop"] = [] - elif not isinstance(cp["stop"], list): - raise ValueError("stop in model.completion_params must be of list type") - - if len(cp["stop"]) > 4: - raise ValueError("stop sequences must be less than 4") - - return cp - - @classmethod - def validate_configuration(cls, tenant_id: str, account: Account, config: dict, app_mode: str) -> dict: - # opening_statement - if 'opening_statement' not in config or not config["opening_statement"]: - config["opening_statement"] = "" - - if not isinstance(config["opening_statement"], str): - raise ValueError("opening_statement must be of string type") - - # suggested_questions - if 'suggested_questions' not in config or not config["suggested_questions"]: - config["suggested_questions"] = [] - - if not isinstance(config["suggested_questions"], list): - raise ValueError("suggested_questions must be of list type") - - for question in config["suggested_questions"]: - if not isinstance(question, str): - raise ValueError("Elements in suggested_questions list must be of string type") - - # suggested_questions_after_answer - if 'suggested_questions_after_answer' not in config or not config["suggested_questions_after_answer"]: - config["suggested_questions_after_answer"] = { - "enabled": False - } - - if not isinstance(config["suggested_questions_after_answer"], dict): - raise ValueError("suggested_questions_after_answer must be of dict type") - - if "enabled" not in config["suggested_questions_after_answer"] or not config["suggested_questions_after_answer"]["enabled"]: - config["suggested_questions_after_answer"]["enabled"] = False - - if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool): - raise ValueError("enabled in suggested_questions_after_answer must be of boolean type") - - # speech_to_text - if 'speech_to_text' not in config or not config["speech_to_text"]: - config["speech_to_text"] = { - "enabled": False - } - - if not isinstance(config["speech_to_text"], dict): - raise ValueError("speech_to_text must be of dict type") - - if "enabled" not in config["speech_to_text"] or not config["speech_to_text"]["enabled"]: - config["speech_to_text"]["enabled"] = False - - if not isinstance(config["speech_to_text"]["enabled"], bool): - raise ValueError("enabled in speech_to_text must be of boolean type") - - # text_to_speech - if 'text_to_speech' not in config or not config["text_to_speech"]: - config["text_to_speech"] = { - "enabled": False, - "voice": "", - "language": "" - } - - if not isinstance(config["text_to_speech"], dict): - raise ValueError("text_to_speech must be of dict type") - - if "enabled" not in config["text_to_speech"] or not config["text_to_speech"]["enabled"]: - config["text_to_speech"]["enabled"] = False - config["text_to_speech"]["voice"] = "" - config["text_to_speech"]["language"] = "" - - if not isinstance(config["text_to_speech"]["enabled"], bool): - raise ValueError("enabled in text_to_speech must be of boolean type") - - # return retriever resource - if 'retriever_resource' not in config or not config["retriever_resource"]: - config["retriever_resource"] = { - "enabled": False - } - - if not isinstance(config["retriever_resource"], dict): - raise ValueError("retriever_resource must be of dict type") - - if "enabled" not in config["retriever_resource"] or not config["retriever_resource"]["enabled"]: - config["retriever_resource"]["enabled"] = 
False - - if not isinstance(config["retriever_resource"]["enabled"], bool): - raise ValueError("enabled in retriever_resource must be of boolean type") - - # more_like_this - if 'more_like_this' not in config or not config["more_like_this"]: - config["more_like_this"] = { - "enabled": False - } - - if not isinstance(config["more_like_this"], dict): - raise ValueError("more_like_this must be of dict type") - - if "enabled" not in config["more_like_this"] or not config["more_like_this"]["enabled"]: - config["more_like_this"]["enabled"] = False - - if not isinstance(config["more_like_this"]["enabled"], bool): - raise ValueError("enabled in more_like_this must be of boolean type") - - # model - if 'model' not in config: - raise ValueError("model is required") - - if not isinstance(config["model"], dict): - raise ValueError("model must be of object type") - - # model.provider - provider_entities = model_provider_factory.get_providers() - model_provider_names = [provider.provider for provider in provider_entities] - if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names: - raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") - - # model.name - if 'name' not in config["model"]: - raise ValueError("model.name is required") - - provider_manager = ProviderManager() - models = provider_manager.get_configurations(tenant_id).get_models( - provider=config["model"]["provider"], - model_type=ModelType.LLM - ) - if not models: - raise ValueError("model.name must be in the specified model list") - - model_ids = [m.model for m in models] - if config["model"]["name"] not in model_ids: - raise ValueError("model.name must be in the specified model list") - - model_mode = None - for model in models: - if model.model == config["model"]["name"]: - model_mode = model.model_properties.get(ModelPropertyKey.MODE) - break - - # model.mode - if model_mode: - config['model']["mode"] = model_mode + def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict: + if app_mode == AppMode.CHAT: + return ChatAppConfigValidator.config_validate(tenant_id, config) + elif app_mode == AppMode.AGENT_CHAT: + return AgentChatAppConfigValidator.config_validate(tenant_id, config) + elif app_mode == AppMode.COMPLETION: + return CompletionAppConfigValidator.config_validate(tenant_id, config) else: - config['model']["mode"] = "completion" - - # model.completion_params - if 'completion_params' not in config["model"]: - raise ValueError("model.completion_params is required") - - config["model"]["completion_params"] = cls.validate_model_completion_params( - config["model"]["completion_params"], - config["model"]["name"] - ) - - # user_input_form - if "user_input_form" not in config or not config["user_input_form"]: - config["user_input_form"] = [] - - if not isinstance(config["user_input_form"], list): - raise ValueError("user_input_form must be a list of objects") - - variables = [] - for item in config["user_input_form"]: - key = list(item.keys())[0] - if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]: - raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'") - - form_item = item[key] - if 'label' not in form_item: - raise ValueError("label is required in user_input_form") - - if not isinstance(form_item["label"], str): - raise ValueError("label in user_input_form must be of string type") - - if 'variable' not in form_item: - raise ValueError("variable is 
required in user_input_form") - - if not isinstance(form_item["variable"], str): - raise ValueError("variable in user_input_form must be of string type") - - pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$") - if pattern.match(form_item["variable"]) is None: - raise ValueError("variable in user_input_form must be a string, " - "and cannot start with a number") - - variables.append(form_item["variable"]) - - if 'required' not in form_item or not form_item["required"]: - form_item["required"] = False - - if not isinstance(form_item["required"], bool): - raise ValueError("required in user_input_form must be of boolean type") - - if key == "select": - if 'options' not in form_item or not form_item["options"]: - form_item["options"] = [] - - if not isinstance(form_item["options"], list): - raise ValueError("options in user_input_form must be a list of strings") - - if "default" in form_item and form_item['default'] \ - and form_item["default"] not in form_item["options"]: - raise ValueError("default value in user_input_form must be in the options list") - - # pre_prompt - if "pre_prompt" not in config or not config["pre_prompt"]: - config["pre_prompt"] = "" - - if not isinstance(config["pre_prompt"], str): - raise ValueError("pre_prompt must be of string type") - - # agent_mode - if "agent_mode" not in config or not config["agent_mode"]: - config["agent_mode"] = { - "enabled": False, - "tools": [] - } - - if not isinstance(config["agent_mode"], dict): - raise ValueError("agent_mode must be of object type") - - if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]: - config["agent_mode"]["enabled"] = False - - if not isinstance(config["agent_mode"]["enabled"], bool): - raise ValueError("enabled in agent_mode must be of boolean type") - - if "strategy" not in config["agent_mode"] or not config["agent_mode"]["strategy"]: - config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value - - if config["agent_mode"]["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]: - raise ValueError("strategy in agent_mode must be in the specified strategy list") - - if "tools" not in config["agent_mode"] or not config["agent_mode"]["tools"]: - config["agent_mode"]["tools"] = [] - - if not isinstance(config["agent_mode"]["tools"], list): - raise ValueError("tools in agent_mode must be a list of objects") - - for tool in config["agent_mode"]["tools"]: - key = list(tool.keys())[0] - if key in SUPPORT_TOOLS: - # old style, use tool name as key - tool_item = tool[key] - - if "enabled" not in tool_item or not tool_item["enabled"]: - tool_item["enabled"] = False - - if not isinstance(tool_item["enabled"], bool): - raise ValueError("enabled in agent_mode.tools must be of boolean type") - - if key == "dataset": - if 'id' not in tool_item: - raise ValueError("id is required in dataset") - - try: - uuid.UUID(tool_item["id"]) - except ValueError: - raise ValueError("id in dataset must be of UUID type") - - if not cls.is_dataset_exists(account, tool_item["id"]): - raise ValueError("Dataset ID does not exist, please check your permission.") - else: - # latest style, use key-value pair - if "enabled" not in tool or not tool["enabled"]: - tool["enabled"] = False - if "provider_type" not in tool: - raise ValueError("provider_type is required in agent_mode.tools") - if "provider_id" not in tool: - raise ValueError("provider_id is required in agent_mode.tools") - if "tool_name" not in tool: - raise 
ValueError("tool_name is required in agent_mode.tools") - if "tool_parameters" not in tool: - raise ValueError("tool_parameters is required in agent_mode.tools") - - # dataset_query_variable - cls.is_dataset_query_variable_valid(config, app_mode) - - # advanced prompt validation - cls.is_advanced_prompt_valid(config, app_mode) - - # external data tools validation - cls.is_external_data_tools_valid(tenant_id, config) - - # moderation validation - cls.is_moderation_valid(tenant_id, config) - - # file upload validation - cls.is_file_upload_valid(config) - - # Filter out extra parameters - filtered_config = { - "opening_statement": config["opening_statement"], - "suggested_questions": config["suggested_questions"], - "suggested_questions_after_answer": config["suggested_questions_after_answer"], - "speech_to_text": config["speech_to_text"], - "text_to_speech": config["text_to_speech"], - "retriever_resource": config["retriever_resource"], - "more_like_this": config["more_like_this"], - "sensitive_word_avoidance": config["sensitive_word_avoidance"], - "external_data_tools": config["external_data_tools"], - "model": { - "provider": config["model"]["provider"], - "name": config["model"]["name"], - "mode": config['model']["mode"], - "completion_params": config["model"]["completion_params"] - }, - "user_input_form": config["user_input_form"], - "dataset_query_variable": config.get('dataset_query_variable'), - "pre_prompt": config["pre_prompt"], - "agent_mode": config["agent_mode"], - "prompt_type": config["prompt_type"], - "chat_prompt_config": config["chat_prompt_config"], - "completion_prompt_config": config["completion_prompt_config"], - "dataset_configs": config["dataset_configs"], - "file_upload": config["file_upload"] - } - - return filtered_config + raise ValueError(f"Invalid app mode: {app_mode}") @classmethod - def is_moderation_valid(cls, tenant_id: str, config: dict): - if 'sensitive_word_avoidance' not in config or not config["sensitive_word_avoidance"]: - config["sensitive_word_avoidance"] = { - "enabled": False - } - - if not isinstance(config["sensitive_word_avoidance"], dict): - raise ValueError("sensitive_word_avoidance must be of dict type") - - if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]: - config["sensitive_word_avoidance"]["enabled"] = False - - if not config["sensitive_word_avoidance"]["enabled"]: - return - - if "type" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["type"]: - raise ValueError("sensitive_word_avoidance.type is required") - - type = config["sensitive_word_avoidance"]["type"] - config = config["sensitive_word_avoidance"]["config"] - - ModerationFactory.validate_config( - name=type, - tenant_id=tenant_id, - config=config - ) - - @classmethod - def is_file_upload_valid(cls, config: dict): - if 'file_upload' not in config or not config["file_upload"]: - config["file_upload"] = {} - - if not isinstance(config["file_upload"], dict): - raise ValueError("file_upload must be of dict type") - - # check image config - if 'image' not in config["file_upload"] or not config["file_upload"]["image"]: - config["file_upload"]["image"] = {"enabled": False} - - if config['file_upload']['image']['enabled']: - number_limits = config['file_upload']['image']['number_limits'] - if number_limits < 1 or number_limits > 6: - raise ValueError("number_limits must be in [1, 6]") - - detail = config['file_upload']['image']['detail'] - if detail not in ['high', 'low']: - raise 
ValueError("detail must be in ['high', 'low']") - - transfer_methods = config['file_upload']['image']['transfer_methods'] - if not isinstance(transfer_methods, list): - raise ValueError("transfer_methods must be of list type") - for method in transfer_methods: - if method not in ['remote_url', 'local_file']: - raise ValueError("transfer_methods must be in ['remote_url', 'local_file']") - - @classmethod - def is_external_data_tools_valid(cls, tenant_id: str, config: dict): - if 'external_data_tools' not in config or not config["external_data_tools"]: - config["external_data_tools"] = [] - - if not isinstance(config["external_data_tools"], list): - raise ValueError("external_data_tools must be of list type") - - for tool in config["external_data_tools"]: - if "enabled" not in tool or not tool["enabled"]: - tool["enabled"] = False - - if not tool["enabled"]: - continue - - if "type" not in tool or not tool["type"]: - raise ValueError("external_data_tools[].type is required") - - type = tool["type"] - config = tool["config"] - - ExternalDataToolFactory.validate_config( - name=type, - tenant_id=tenant_id, - config=config - ) - - @classmethod - def is_dataset_query_variable_valid(cls, config: dict, mode: str) -> None: - # Only check when mode is completion - if mode != 'completion': - return - - agent_mode = config.get("agent_mode", {}) - tools = agent_mode.get("tools", []) - dataset_exists = "dataset" in str(tools) - - dataset_query_variable = config.get("dataset_query_variable") - - if dataset_exists and not dataset_query_variable: - raise ValueError("Dataset query variable is required when dataset is exist") - - @classmethod - def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None: - # prompt_type - if 'prompt_type' not in config or not config["prompt_type"]: - config["prompt_type"] = "simple" - - if config['prompt_type'] not in ['simple', 'advanced']: - raise ValueError("prompt_type must be in ['simple', 'advanced']") - - # chat_prompt_config - if 'chat_prompt_config' not in config or not config["chat_prompt_config"]: - config["chat_prompt_config"] = {} - - if not isinstance(config["chat_prompt_config"], dict): - raise ValueError("chat_prompt_config must be of object type") - - # completion_prompt_config - if 'completion_prompt_config' not in config or not config["completion_prompt_config"]: - config["completion_prompt_config"] = {} - - if not isinstance(config["completion_prompt_config"], dict): - raise ValueError("completion_prompt_config must be of object type") - - # dataset_configs - if 'dataset_configs' not in config or not config["dataset_configs"]: - config["dataset_configs"] = {'retrieval_model': 'single'} - - if 'datasets' not in config["dataset_configs"] or not config["dataset_configs"]["datasets"]: - config["dataset_configs"]["datasets"] = { - "strategy": "router", - "datasets": [] - } - - if not isinstance(config["dataset_configs"], dict): - raise ValueError("dataset_configs must be of object type") - - if config["dataset_configs"]['retrieval_model'] == 'multiple': - if not config["dataset_configs"]['reranking_model']: - raise ValueError("reranking_model has not been set") - if not isinstance(config["dataset_configs"]['reranking_model'], dict): - raise ValueError("reranking_model must be of object type") - - if not isinstance(config["dataset_configs"], dict): - raise ValueError("dataset_configs must be of object type") - - if config['prompt_type'] == 'advanced': - if not config['chat_prompt_config'] and not config['completion_prompt_config']: - raise 
ValueError("chat_prompt_config or completion_prompt_config is required when prompt_type is advanced") - - if config['model']["mode"] not in ['chat', 'completion']: - raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced") - - if app_mode == AppMode.CHAT.value and config['model']["mode"] == "completion": - user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix'] - assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] - - if not user_prefix: - config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human' - - if not assistant_prefix: - config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant' - - if config['model']["mode"] == "chat": - prompt_list = config['chat_prompt_config']['prompt'] - - if len(prompt_list) > 10: - raise ValueError("prompt messages must be less than 10") + def validate_features(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict: + if app_mode == AppMode.ADVANCED_CHAT: + return AdvancedChatAppConfigValidator.config_validate(tenant_id, config) + elif app_mode == AppMode.WORKFLOW: + return WorkflowAppConfigValidator.config_validate(tenant_id, config) + else: + raise ValueError(f"Invalid app mode: {app_mode}") diff --git a/api/services/completion_service.py b/api/services/completion_service.py index cbfbe9ef41..6dd729694b 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -5,10 +5,11 @@ from typing import Any, Union from sqlalchemy import and_ from core.application_manager import ApplicationManager +from core.apps.config_validators.model import ModelValidator from core.entities.application_entities import InvokeFrom from core.file.message_file_parser import MessageFileParser from extensions.ext_database import db -from models.model import Account, App, AppModelConfig, Conversation, EndUser, Message +from models.model import Account, App, AppModelConfig, Conversation, EndUser, Message, AppMode from services.app_model_config_service import AppModelConfigService from services.errors.app import MoreLikeThisDisabledError from services.errors.app_model_config import AppModelConfigBrokenError @@ -88,9 +89,8 @@ class CompletionService: if 'completion_params' not in args['model_config']['model']: raise ValueError('model_config.model.completion_params is required') - completion_params = AppModelConfigService.validate_model_completion_params( - cp=args['model_config']['model']['completion_params'], - model_name=app_model_config.model_dict["name"] + completion_params = ModelValidator.validate_model_completion_params( + cp=args['model_config']['model']['completion_params'] ) app_model_config_model = app_model_config.model_dict @@ -115,9 +115,8 @@ class CompletionService: # validate config model_config = AppModelConfigService.validate_configuration( tenant_id=app_model.tenant_id, - account=user, config=args['model_config'], - app_mode=app_model.mode + app_mode=AppMode.value_of(app_model.mode) ) app_model_config = AppModelConfig( diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index ae6e4c46d3..5a9234c70a 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -96,7 +96,7 @@ class WorkflowService: if not draft_workflow: raise ValueError('No valid workflow found.') - # TODO check if the workflow is valid + # TODO check if the workflow is valid, basic check # create new workflow 
workflow = Workflow( From d741527ae4b6f7257c9ceb243f8c2190fa226632 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 28 Feb 2024 22:16:36 +0800 Subject: [PATCH 081/200] lint --- api/controllers/console/app/model_config.py | 2 +- api/core/apps/config_validators/agent.py | 3 +-- api/core/apps/config_validators/dataset.py | 3 +-- api/core/apps/config_validators/external_data_tools.py | 3 +-- api/core/apps/config_validators/file_upload.py | 3 +-- api/core/apps/config_validators/model.py | 5 ++--- api/core/apps/config_validators/moderation.py | 3 +-- api/core/apps/config_validators/more_like_this.py | 3 +-- api/core/apps/config_validators/opening_statement.py | 3 +-- api/core/apps/config_validators/prompt.py | 3 +-- api/core/apps/config_validators/retriever_resource.py | 3 +-- api/core/apps/config_validators/speech_to_text.py | 3 +-- api/core/apps/config_validators/suggested_questions.py | 3 +-- api/core/apps/config_validators/text_to_speech.py | 3 +-- api/core/apps/config_validators/user_input_form.py | 3 +-- api/services/completion_service.py | 2 +- 16 files changed, 17 insertions(+), 31 deletions(-) diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 0ae9f5e546..d822f859bc 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -14,7 +14,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_model_config_was_updated from extensions.ext_database import db from libs.login import login_required -from models.model import AppModelConfig, AppMode +from models.model import AppMode, AppModelConfig from services.app_model_config_service import AppModelConfigService diff --git a/api/core/apps/config_validators/agent.py b/api/core/apps/config_validators/agent.py index 69f9338080..c6584d2903 100644 --- a/api/core/apps/config_validators/agent.py +++ b/api/core/apps/config_validators/agent.py @@ -1,5 +1,4 @@ import uuid -from typing import Tuple from core.agent.agent_executor import PlanningStrategy from core.apps.config_validators.dataset import DatasetValidator @@ -9,7 +8,7 @@ OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_dat class AgentValidator: @classmethod - def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for agent feature diff --git a/api/core/apps/config_validators/dataset.py b/api/core/apps/config_validators/dataset.py index 32db038c21..9846f9085c 100644 --- a/api/core/apps/config_validators/dataset.py +++ b/api/core/apps/config_validators/dataset.py @@ -1,5 +1,4 @@ import uuid -from typing import Tuple from core.agent.agent_executor import PlanningStrategy from models.model import AppMode @@ -8,7 +7,7 @@ from services.dataset_service import DatasetService class DatasetValidator: @classmethod - def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for dataset feature diff --git a/api/core/apps/config_validators/external_data_tools.py b/api/core/apps/config_validators/external_data_tools.py index 5412366a89..02ecc8d715 100644 --- a/api/core/apps/config_validators/external_data_tools.py +++ 
b/api/core/apps/config_validators/external_data_tools.py @@ -1,11 +1,10 @@ -from typing import Tuple from core.external_data_tool.factory import ExternalDataToolFactory class ExternalDataToolsValidator: @classmethod - def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for external data fetch feature diff --git a/api/core/apps/config_validators/file_upload.py b/api/core/apps/config_validators/file_upload.py index f9adbfdf7d..419465bd51 100644 --- a/api/core/apps/config_validators/file_upload.py +++ b/api/core/apps/config_validators/file_upload.py @@ -1,9 +1,8 @@ -from typing import Tuple class FileUploadValidator: @classmethod - def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for file upload feature diff --git a/api/core/apps/config_validators/model.py b/api/core/apps/config_validators/model.py index 091eec4683..1d86fbaf04 100644 --- a/api/core/apps/config_validators/model.py +++ b/api/core/apps/config_validators/model.py @@ -1,13 +1,12 @@ -from typing import Tuple -from core.model_runtime.entities.model_entities import ModelType, ModelPropertyKey +from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.model_providers import model_provider_factory from core.provider_manager import ProviderManager class ModelValidator: @classmethod - def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for model config diff --git a/api/core/apps/config_validators/moderation.py b/api/core/apps/config_validators/moderation.py index 1962f87aa9..4813385588 100644 --- a/api/core/apps/config_validators/moderation.py +++ b/api/core/apps/config_validators/moderation.py @@ -1,5 +1,4 @@ import logging -from typing import Tuple from core.moderation.factory import ModerationFactory @@ -8,7 +7,7 @@ logger = logging.getLogger(__name__) class ModerationValidator: @classmethod - def validate_and_set_defaults(cls, tenant_id, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, tenant_id, config: dict) -> tuple[dict, list[str]]: if not config.get("sensitive_word_avoidance"): config["sensitive_word_avoidance"] = { "enabled": False diff --git a/api/core/apps/config_validators/more_like_this.py b/api/core/apps/config_validators/more_like_this.py index 60dc4a0562..1c1bac9de6 100644 --- a/api/core/apps/config_validators/more_like_this.py +++ b/api/core/apps/config_validators/more_like_this.py @@ -1,9 +1,8 @@ -from typing import Tuple class MoreLikeThisValidator: @classmethod - def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for more like this feature diff --git a/api/core/apps/config_validators/opening_statement.py b/api/core/apps/config_validators/opening_statement.py index 3f69e0e946..f919230e0d 100644 --- a/api/core/apps/config_validators/opening_statement.py +++ b/api/core/apps/config_validators/opening_statement.py @@ -1,9 +1,8 @@ -from typing import Tuple class OpeningStatementValidator: @classmethod - def validate_and_set_defaults(cls, config: dict) -> 
Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for opening statement feature diff --git a/api/core/apps/config_validators/prompt.py b/api/core/apps/config_validators/prompt.py index 815706b10b..288a523415 100644 --- a/api/core/apps/config_validators/prompt.py +++ b/api/core/apps/config_validators/prompt.py @@ -1,4 +1,3 @@ -from typing import Tuple from core.entities.application_entities import PromptTemplateEntity from core.prompt.simple_prompt_transform import ModelMode @@ -7,7 +6,7 @@ from models.model import AppMode class PromptValidator: @classmethod - def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]: """ Validate pre_prompt and set defaults for prompt feature depending on the config['model'] diff --git a/api/core/apps/config_validators/retriever_resource.py b/api/core/apps/config_validators/retriever_resource.py index a8bcd60abe..32725c7432 100644 --- a/api/core/apps/config_validators/retriever_resource.py +++ b/api/core/apps/config_validators/retriever_resource.py @@ -1,9 +1,8 @@ -from typing import Tuple class RetrieverResourceValidator: @classmethod - def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for retriever resource feature diff --git a/api/core/apps/config_validators/speech_to_text.py b/api/core/apps/config_validators/speech_to_text.py index 577bef0e59..92a1b25ae6 100644 --- a/api/core/apps/config_validators/speech_to_text.py +++ b/api/core/apps/config_validators/speech_to_text.py @@ -1,9 +1,8 @@ -from typing import Tuple class SpeechToTextValidator: @classmethod - def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for speech to text feature diff --git a/api/core/apps/config_validators/suggested_questions.py b/api/core/apps/config_validators/suggested_questions.py index 938b66bb6e..9161b31678 100644 --- a/api/core/apps/config_validators/suggested_questions.py +++ b/api/core/apps/config_validators/suggested_questions.py @@ -1,9 +1,8 @@ -from typing import Tuple class SuggestedQuestionsValidator: @classmethod - def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for suggested questions feature diff --git a/api/core/apps/config_validators/text_to_speech.py b/api/core/apps/config_validators/text_to_speech.py index efe34a8a3e..182a912d52 100644 --- a/api/core/apps/config_validators/text_to_speech.py +++ b/api/core/apps/config_validators/text_to_speech.py @@ -1,9 +1,8 @@ -from typing import Tuple class TextToSpeechValidator: @classmethod - def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for text to speech feature diff --git a/api/core/apps/config_validators/user_input_form.py b/api/core/apps/config_validators/user_input_form.py index 7116c55afc..249d6745ae 100644 --- a/api/core/apps/config_validators/user_input_form.py +++ b/api/core/apps/config_validators/user_input_form.py @@ -1,10 +1,9 @@ import re -from typing import Tuple class 
UserInputFormValidator: @classmethod - def validate_and_set_defaults(cls, config: dict) -> Tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for user input form diff --git a/api/services/completion_service.py b/api/services/completion_service.py index 6dd729694b..9acd62b997 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -9,7 +9,7 @@ from core.apps.config_validators.model import ModelValidator from core.entities.application_entities import InvokeFrom from core.file.message_file_parser import MessageFileParser from extensions.ext_database import db -from models.model import Account, App, AppModelConfig, Conversation, EndUser, Message, AppMode +from models.model import Account, App, AppMode, AppModelConfig, Conversation, EndUser, Message from services.app_model_config_service import AppModelConfigService from services.errors.app import MoreLikeThisDisabledError from services.errors.app_model_config import AppModelConfigBrokenError From 3badc4423a6fb91642b2263c68cc4442d06a3787 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 29 Feb 2024 12:22:30 +0800 Subject: [PATCH 082/200] fix: wrong default model parameters when creating app --- api/constants/model_template.py | 28 ++++------------------------ 1 file changed, 4 insertions(+), 24 deletions(-) diff --git a/api/constants/model_template.py b/api/constants/model_template.py index ca0b754989..61aab64d8a 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -23,13 +23,7 @@ default_app_templates = { "provider": "openai", "name": "gpt-4", "mode": "chat", - "completion_params": { - "max_tokens": 512, - "temperature": 1, - "top_p": 1, - "presence_penalty": 0, - "frequency_penalty": 0 - } + "completion_params": {} } } }, @@ -46,13 +40,7 @@ default_app_templates = { "provider": "openai", "name": "gpt-4", "mode": "chat", - "completion_params": { - "max_tokens": 512, - "temperature": 1, - "top_p": 1, - "presence_penalty": 0, - "frequency_penalty": 0 - } + "completion_params": {} } } }, @@ -69,16 +57,8 @@ default_app_templates = { "provider": "openai", "name": "gpt-4", "mode": "chat", - "completion_params": { - "max_tokens": 512, - "temperature": 1, - "top_p": 1, - "presence_penalty": 0, - "frequency_penalty": 0 - } + "completion_params": {} } } - }, + } } - - From 896c20021156bd3877b844f122375e01c92ba4b7 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 29 Feb 2024 13:24:26 +0800 Subject: [PATCH 083/200] fix import problem --- api/core/apps/config_validators/agent.py | 2 +- api/core/apps/config_validators/dataset.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api/core/apps/config_validators/agent.py b/api/core/apps/config_validators/agent.py index c6584d2903..b445aedbf8 100644 --- a/api/core/apps/config_validators/agent.py +++ b/api/core/apps/config_validators/agent.py @@ -1,7 +1,7 @@ import uuid -from core.agent.agent_executor import PlanningStrategy from core.apps.config_validators.dataset import DatasetValidator +from core.entities.agent_entities import PlanningStrategy OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] diff --git a/api/core/apps/config_validators/dataset.py b/api/core/apps/config_validators/dataset.py index 9846f9085c..fb5b648320 100644 --- a/api/core/apps/config_validators/dataset.py +++ b/api/core/apps/config_validators/dataset.py @@ -1,6 +1,6 @@ import uuid -from core.agent.agent_executor import 
PlanningStrategy +from core.entities.agent_entities import PlanningStrategy from models.model import AppMode from services.dataset_service import DatasetService From 799db69e4f334a20cbbfad540b518bffc4b698d9 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 29 Feb 2024 17:33:52 +0800 Subject: [PATCH 084/200] refactor app --- api/controllers/console/app/completion.py | 6 +- api/controllers/console/app/generator.py | 2 +- api/controllers/console/explore/completion.py | 6 +- api/controllers/service_api/app/completion.py | 6 +- api/controllers/web/completion.py | 6 +- api/core/{app_runner => agent}/__init__.py | 0 .../base_agent_runner.py} | 8 +- .../cot_agent_runner.py} | 6 +- .../fc_agent_runner.py} | 6 +- api/core/{apps => app}/__init__.py | 0 .../advanced_chat}/__init__.py | 0 .../advanced_chat/config_validator.py} | 14 +- .../agent_chat}/__init__.py | 0 .../agent_chat/app_runner.py} | 19 +- api/core/app/agent_chat/config_validator.py | 162 +++++++ api/core/app/app_manager.py | 382 +++++++++++++++ .../app_orchestration_config_converter.py} | 434 +----------------- .../app_queue_manager.py} | 6 +- .../app_runner.py => app/base_app_runner.py} | 26 +- api/core/{features => app/chat}/__init__.py | 0 .../chat/app_runner.py} | 16 +- .../chat/config_validator.py} | 26 +- .../completion}/__init__.py | 0 api/core/app/completion/app_runner.py | 266 +++++++++++ .../completion/config_validator.py} | 20 +- .../agent => app/features}/__init__.py | 0 .../features/annotation_reply}/__init__.py | 0 .../annotation_reply}/annotation_reply.py | 0 .../features/hosting_moderation/__init__.py | 0 .../hosting_moderation}/hosting_moderation.py | 0 .../generate_task_pipeline.py | 12 +- api/core/app/validators/__init__.py | 0 .../validators/dataset_retrieval.py} | 0 .../validators/external_data_fetch.py} | 2 +- .../validators}/file_upload.py | 0 .../validators/model_validator.py} | 0 .../validators}/moderation.py | 0 .../validators}/more_like_this.py | 0 .../validators}/opening_statement.py | 0 .../validators}/prompt.py | 0 .../validators}/retriever_resource.py | 0 .../validators}/speech_to_text.py | 0 .../validators}/suggested_questions.py | 0 .../validators}/text_to_speech.py | 0 .../validators}/user_input_form.py | 0 api/core/app/workflow/__init__.py | 0 .../workflow/config_validator.py} | 6 +- .../app_config_validators/agent_chat_app.py | 82 ---- api/core/apps/config_validators/agent.py | 81 ---- .../agent_loop_gather_callback_handler.py | 4 +- .../index_tool_callback_handler.py | 4 +- .../external_data_fetch.py | 2 +- api/core/indexing_runner.py | 2 +- api/core/llm_generator/__init__.py | 0 .../llm_generator.py | 8 +- .../llm_generator/output_parser/__init__.py | 0 .../output_parser/rule_config_generator.py | 2 +- .../suggested_questions_after_answer.py | 2 +- api/core/{prompt => llm_generator}/prompts.py | 0 .../input_moderation.py} | 2 +- .../output_moderation.py} | 4 +- api/core/prompt/__init__.py | 0 api/core/prompt/advanced_prompt_transform.py | 2 +- api/core/prompt/prompt_templates/__init__.py | 0 .../advanced_prompt_templates.py | 0 .../baichuan_chat.json | 0 .../baichuan_completion.json | 0 .../common_chat.json | 0 .../common_completion.json | 0 api/core/prompt/simple_prompt_transform.py | 4 +- api/core/prompt/utils/__init__.py | 0 .../prompt_template_parser.py} | 0 .../processor/qa_index_processor.py | 2 +- api/core/rag/retrieval/__init__.py | 0 api/core/rag/retrieval/agent/__init__.py | 0 .../retrieval}/agent/agent_llm_callback.py | 0 .../retrieval}/agent/fake_llm.py | 0 
.../retrieval}/agent/llm_chain.py | 4 +- .../agent/multi_dataset_router_agent.py | 2 +- .../retrieval/agent/output_parser/__init__.py | 0 .../agent/output_parser/structured_chat.py | 0 .../structed_multi_dataset_router_agent.py | 2 +- .../agent_based_dataset_executor.py | 8 +- .../retrieval}/dataset_retrieval.py | 4 +- api/core/tools/tool/dataset_retriever_tool.py | 4 +- ...rsation_name_when_first_message_created.py | 2 +- api/models/model.py | 18 +- .../advanced_prompt_template_service.py | 2 +- api/services/app_model_config_service.py | 10 +- api/services/completion_service.py | 8 +- api/services/conversation_service.py | 2 +- api/services/message_service.py | 2 +- api/services/workflow/workflow_converter.py | 4 +- .../prompt/test_advanced_prompt_transform.py | 2 +- 94 files changed, 991 insertions(+), 721 deletions(-) rename api/core/{app_runner => agent}/__init__.py (100%) rename api/core/{features/assistant_base_runner.py => agent/base_agent_runner.py} (99%) rename api/core/{features/assistant_cot_runner.py => agent/cot_agent_runner.py} (99%) rename api/core/{features/assistant_fc_runner.py => agent/fc_agent_runner.py} (98%) rename api/core/{apps => app}/__init__.py (100%) rename api/core/{apps/app_config_validators => app/advanced_chat}/__init__.py (100%) rename api/core/{apps/app_config_validators/advanced_chat_app.py => app/advanced_chat/config_validator.py} (77%) rename api/core/{apps/config_validators => app/agent_chat}/__init__.py (100%) rename api/core/{app_runner/assistant_app_runner.py => app/agent_chat/app_runner.py} (95%) create mode 100644 api/core/app/agent_chat/config_validator.py create mode 100644 api/core/app/app_manager.py rename api/core/{application_manager.py => app/app_orchestration_config_converter.py} (52%) rename api/core/{application_queue_manager.py => app/app_queue_manager.py} (97%) rename api/core/{app_runner/app_runner.py => app/base_app_runner.py} (94%) rename api/core/{features => app/chat}/__init__.py (100%) rename api/core/{app_runner/basic_app_runner.py => app/chat/app_runner.py} (95%) rename api/core/{apps/app_config_validators/chat_app.py => app/chat/config_validator.py} (75%) rename api/core/{features/dataset_retrieval => app/completion}/__init__.py (100%) create mode 100644 api/core/app/completion/app_runner.py rename api/core/{apps/app_config_validators/completion_app.py => app/completion/config_validator.py} (76%) rename api/core/{features/dataset_retrieval/agent => app/features}/__init__.py (100%) rename api/core/{features/dataset_retrieval/agent/output_parser => app/features/annotation_reply}/__init__.py (100%) rename api/core/{features => app/features/annotation_reply}/annotation_reply.py (100%) create mode 100644 api/core/app/features/hosting_moderation/__init__.py rename api/core/{features => app/features/hosting_moderation}/hosting_moderation.py (100%) rename api/core/{app_runner => app}/generate_task_pipeline.py (98%) create mode 100644 api/core/app/validators/__init__.py rename api/core/{apps/config_validators/dataset.py => app/validators/dataset_retrieval.py} (100%) rename api/core/{apps/config_validators/external_data_tools.py => app/validators/external_data_fetch.py} (97%) rename api/core/{apps/config_validators => app/validators}/file_upload.py (100%) rename api/core/{apps/config_validators/model.py => app/validators/model_validator.py} (100%) rename api/core/{apps/config_validators => app/validators}/moderation.py (100%) rename api/core/{apps/config_validators => app/validators}/more_like_this.py (100%) rename 
api/core/{apps/config_validators => app/validators}/opening_statement.py (100%) rename api/core/{apps/config_validators => app/validators}/prompt.py (100%) rename api/core/{apps/config_validators => app/validators}/retriever_resource.py (100%) rename api/core/{apps/config_validators => app/validators}/speech_to_text.py (100%) rename api/core/{apps/config_validators => app/validators}/suggested_questions.py (100%) rename api/core/{apps/config_validators => app/validators}/text_to_speech.py (100%) rename api/core/{apps/config_validators => app/validators}/user_input_form.py (100%) create mode 100644 api/core/app/workflow/__init__.py rename api/core/{apps/app_config_validators/workflow_app.py => app/workflow/config_validator.py} (83%) delete mode 100644 api/core/apps/app_config_validators/agent_chat_app.py delete mode 100644 api/core/apps/config_validators/agent.py rename api/core/{features => external_data_tool}/external_data_fetch.py (98%) create mode 100644 api/core/llm_generator/__init__.py rename api/core/{generator => llm_generator}/llm_generator.py (93%) create mode 100644 api/core/llm_generator/output_parser/__init__.py rename api/core/{prompt => llm_generator}/output_parser/rule_config_generator.py (94%) rename api/core/{prompt => llm_generator}/output_parser/suggested_questions_after_answer.py (87%) rename api/core/{prompt => llm_generator}/prompts.py (100%) rename api/core/{features/moderation.py => moderation/input_moderation.py} (98%) rename api/core/{app_runner/moderation_handler.py => moderation/output_moderation.py} (97%) create mode 100644 api/core/prompt/__init__.py create mode 100644 api/core/prompt/prompt_templates/__init__.py rename api/core/prompt/{ => prompt_templates}/advanced_prompt_templates.py (100%) rename api/core/prompt/{generate_prompts => prompt_templates}/baichuan_chat.json (100%) rename api/core/prompt/{generate_prompts => prompt_templates}/baichuan_completion.json (100%) rename api/core/prompt/{generate_prompts => prompt_templates}/common_chat.json (100%) rename api/core/prompt/{generate_prompts => prompt_templates}/common_completion.json (100%) create mode 100644 api/core/prompt/utils/__init__.py rename api/core/prompt/{prompt_template.py => utils/prompt_template_parser.py} (100%) create mode 100644 api/core/rag/retrieval/__init__.py create mode 100644 api/core/rag/retrieval/agent/__init__.py rename api/core/{features/dataset_retrieval => rag/retrieval}/agent/agent_llm_callback.py (100%) rename api/core/{features/dataset_retrieval => rag/retrieval}/agent/fake_llm.py (100%) rename api/core/{features/dataset_retrieval => rag/retrieval}/agent/llm_chain.py (91%) rename api/core/{features/dataset_retrieval => rag/retrieval}/agent/multi_dataset_router_agent.py (98%) create mode 100644 api/core/rag/retrieval/agent/output_parser/__init__.py rename api/core/{features/dataset_retrieval => rag/retrieval}/agent/output_parser/structured_chat.py (100%) rename api/core/{features/dataset_retrieval => rag/retrieval}/agent/structed_multi_dataset_router_agent.py (99%) rename api/core/{features/dataset_retrieval => rag/retrieval}/agent_based_dataset_executor.py (92%) rename api/core/{features/dataset_retrieval => rag/retrieval}/dataset_retrieval.py (98%) diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index e62475308f..0632c0439b 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -21,7 +21,7 @@ from controllers.console.app.error import ( from controllers.console.app.wraps 
import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.application_queue_manager import ApplicationQueueManager +from core.app.app_queue_manager import AppQueueManager from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError @@ -94,7 +94,7 @@ class CompletionMessageStopApi(Resource): def post(self, app_model, task_id): account = flask_login.current_user - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) return {'result': 'success'}, 200 @@ -172,7 +172,7 @@ class ChatMessageStopApi(Resource): def post(self, app_model, task_id): account = flask_login.current_user - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) return {'result': 'success'}, 200 diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 3ec932b5f1..ee02fc1846 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -11,7 +11,7 @@ from controllers.console.app.error import ( from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.generator.llm_generator import LLMGenerator +from core.llm_generator.llm_generator import LLMGenerator from core.model_runtime.errors.invoke import InvokeError from libs.login import login_required diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 6406d5b3b0..22ea4bbac2 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -21,7 +21,7 @@ from controllers.console.app.error import ( ) from controllers.console.explore.error import NotChatAppError, NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource -from core.application_queue_manager import ApplicationQueueManager +from core.app.app_queue_manager import AppQueueManager from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError @@ -90,7 +90,7 @@ class CompletionStopApi(InstalledAppResource): if app_model.mode != 'completion': raise NotCompletionAppError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {'result': 'success'}, 200 @@ -154,7 +154,7 @@ class ChatStopApi(InstalledAppResource): if app_model.mode != 'chat': raise NotChatAppError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {'result': 'success'}, 200 diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index c6cfb24378..fd4ce831b3 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ 
-19,7 +19,7 @@ from controllers.service_api.app.error import ( ProviderQuotaExceededError, ) from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token -from core.application_queue_manager import ApplicationQueueManager +from core.app.app_queue_manager import AppQueueManager from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError @@ -85,7 +85,7 @@ class CompletionStopApi(Resource): if app_model.mode != 'completion': raise AppUnavailableError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) return {'result': 'success'}, 200 @@ -147,7 +147,7 @@ class ChatStopApi(Resource): if app_model.mode != 'chat': raise NotChatAppError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) return {'result': 'success'}, 200 diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 61d4f8c362..fd94ec7646 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -20,7 +20,7 @@ from controllers.web.error import ( ProviderQuotaExceededError, ) from controllers.web.wraps import WebApiResource -from core.application_queue_manager import ApplicationQueueManager +from core.app.app_queue_manager import AppQueueManager from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError @@ -84,7 +84,7 @@ class CompletionStopApi(WebApiResource): if app_model.mode != 'completion': raise NotCompletionAppError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) return {'result': 'success'}, 200 @@ -144,7 +144,7 @@ class ChatStopApi(WebApiResource): if app_model.mode != 'chat': raise NotChatAppError() - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) + AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) return {'result': 'success'}, 200 diff --git a/api/core/app_runner/__init__.py b/api/core/agent/__init__.py similarity index 100% rename from api/core/app_runner/__init__.py rename to api/core/agent/__init__.py diff --git a/api/core/features/assistant_base_runner.py b/api/core/agent/base_agent_runner.py similarity index 99% rename from api/core/features/assistant_base_runner.py rename to api/core/agent/base_agent_runner.py index 1d9541070f..0658124d14 100644 --- a/api/core/features/assistant_base_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -5,8 +5,8 @@ from datetime import datetime from mimetypes import guess_extension from typing import Optional, Union, cast -from core.app_runner.app_runner import AppRunner -from core.application_queue_manager import ApplicationQueueManager +from core.app.base_app_runner import AppRunner +from core.app.app_queue_manager import AppQueueManager from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.application_entities import ( @@ -48,13 +48,13 @@ 
from models.tools import ToolConversationVariables logger = logging.getLogger(__name__) -class BaseAssistantApplicationRunner(AppRunner): +class BaseAgentRunner(AppRunner): def __init__(self, tenant_id: str, application_generate_entity: ApplicationGenerateEntity, app_orchestration_config: AppOrchestrationConfigEntity, model_config: ModelConfigEntity, config: AgentEntity, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, message: Message, user_id: str, memory: Optional[TokenBufferMemory] = None, diff --git a/api/core/features/assistant_cot_runner.py b/api/core/agent/cot_agent_runner.py similarity index 99% rename from api/core/features/assistant_cot_runner.py rename to api/core/agent/cot_agent_runner.py index 3762ddcf62..152e445795 100644 --- a/api/core/features/assistant_cot_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -3,9 +3,9 @@ import re from collections.abc import Generator from typing import Literal, Union -from core.application_queue_manager import PublishFrom +from core.app.app_queue_manager import PublishFrom from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit -from core.features.assistant_base_runner import BaseAssistantApplicationRunner +from core.agent.base_agent_runner import BaseAgentRunner from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -262,7 +262,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): tool_call_args = json.loads(tool_call_args) except json.JSONDecodeError: pass - + tool_response = tool_instance.invoke( user_id=self.user_id, tool_parameters=tool_call_args diff --git a/api/core/features/assistant_fc_runner.py b/api/core/agent/fc_agent_runner.py similarity index 98% rename from api/core/features/assistant_fc_runner.py rename to api/core/agent/fc_agent_runner.py index 391e040c53..0cf0d3762c 100644 --- a/api/core/features/assistant_fc_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -3,8 +3,8 @@ import logging from collections.abc import Generator from typing import Any, Union -from core.application_queue_manager import PublishFrom -from core.features.assistant_base_runner import BaseAssistantApplicationRunner +from core.app.app_queue_manager import PublishFrom +from core.agent.base_agent_runner import BaseAgentRunner from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -26,7 +26,7 @@ from models.model import Conversation, Message, MessageAgentThought logger = logging.getLogger(__name__) -class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): +class FunctionCallAgentRunner(BaseAgentRunner): def run(self, conversation: Conversation, message: Message, query: str, diff --git a/api/core/apps/__init__.py b/api/core/app/__init__.py similarity index 100% rename from api/core/apps/__init__.py rename to api/core/app/__init__.py diff --git a/api/core/apps/app_config_validators/__init__.py b/api/core/app/advanced_chat/__init__.py similarity index 100% rename from api/core/apps/app_config_validators/__init__.py rename to api/core/app/advanced_chat/__init__.py diff --git a/api/core/apps/app_config_validators/advanced_chat_app.py b/api/core/app/advanced_chat/config_validator.py similarity index 77% rename from api/core/apps/app_config_validators/advanced_chat_app.py rename to 
api/core/app/advanced_chat/config_validator.py index dc7664b844..39c00c028e 100644 --- a/api/core/apps/app_config_validators/advanced_chat_app.py +++ b/api/core/app/advanced_chat/config_validator.py @@ -1,10 +1,10 @@ -from core.apps.config_validators.file_upload import FileUploadValidator -from core.apps.config_validators.moderation import ModerationValidator -from core.apps.config_validators.opening_statement import OpeningStatementValidator -from core.apps.config_validators.retriever_resource import RetrieverResourceValidator -from core.apps.config_validators.speech_to_text import SpeechToTextValidator -from core.apps.config_validators.suggested_questions import SuggestedQuestionsValidator -from core.apps.config_validators.text_to_speech import TextToSpeechValidator +from core.app.validators.file_upload import FileUploadValidator +from core.app.validators.moderation import ModerationValidator +from core.app.validators.opening_statement import OpeningStatementValidator +from core.app.validators.retriever_resource import RetrieverResourceValidator +from core.app.validators.speech_to_text import SpeechToTextValidator +from core.app.validators.suggested_questions import SuggestedQuestionsValidator +from core.app.validators.text_to_speech import TextToSpeechValidator class AdvancedChatAppConfigValidator: diff --git a/api/core/apps/config_validators/__init__.py b/api/core/app/agent_chat/__init__.py similarity index 100% rename from api/core/apps/config_validators/__init__.py rename to api/core/app/agent_chat/__init__.py diff --git a/api/core/app_runner/assistant_app_runner.py b/api/core/app/agent_chat/app_runner.py similarity index 95% rename from api/core/app_runner/assistant_app_runner.py rename to api/core/app/agent_chat/app_runner.py index 655a5a1c7c..b046e935a5 100644 --- a/api/core/app_runner/assistant_app_runner.py +++ b/api/core/app/agent_chat/app_runner.py @@ -1,11 +1,11 @@ import logging from typing import cast -from core.app_runner.app_runner import AppRunner -from core.application_queue_manager import ApplicationQueueManager, PublishFrom +from core.app.base_app_runner import AppRunner +from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.entities.application_entities import AgentEntity, ApplicationGenerateEntity, ModelConfigEntity -from core.features.assistant_cot_runner import AssistantCotApplicationRunner -from core.features.assistant_fc_runner import AssistantFunctionCallApplicationRunner +from core.agent.cot_agent_runner import CotAgentRunner +from core.agent.fc_agent_runner import FunctionCallAgentRunner from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMUsage @@ -19,12 +19,13 @@ from models.tools import ToolConversationVariables logger = logging.getLogger(__name__) -class AssistantApplicationRunner(AppRunner): + +class AgentChatAppRunner(AppRunner): """ - Assistant Application Runner + Agent Application Runner """ def run(self, application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, conversation: Conversation, message: Message) -> None: """ @@ -201,7 +202,7 @@ class AssistantApplicationRunner(AppRunner): # start agent runner if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: - assistant_cot_runner = AssistantCotApplicationRunner( + assistant_cot_runner = CotAgentRunner( tenant_id=application_generate_entity.tenant_id, 
application_generate_entity=application_generate_entity, app_orchestration_config=app_orchestration_config, @@ -223,7 +224,7 @@ class AssistantApplicationRunner(AppRunner): inputs=inputs, ) elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING: - assistant_fc_runner = AssistantFunctionCallApplicationRunner( + assistant_fc_runner = FunctionCallAgentRunner( tenant_id=application_generate_entity.tenant_id, application_generate_entity=application_generate_entity, app_orchestration_config=app_orchestration_config, diff --git a/api/core/app/agent_chat/config_validator.py b/api/core/app/agent_chat/config_validator.py new file mode 100644 index 0000000000..6596b19f99 --- /dev/null +++ b/api/core/app/agent_chat/config_validator.py @@ -0,0 +1,162 @@ +import uuid + +from core.entities.agent_entities import PlanningStrategy +from core.app.validators.dataset_retrieval import DatasetValidator +from core.app.validators.external_data_fetch import ExternalDataFetchValidator +from core.app.validators.file_upload import FileUploadValidator +from core.app.validators.model_validator import ModelValidator +from core.app.validators.moderation import ModerationValidator +from core.app.validators.opening_statement import OpeningStatementValidator +from core.app.validators.prompt import PromptValidator +from core.app.validators.retriever_resource import RetrieverResourceValidator +from core.app.validators.speech_to_text import SpeechToTextValidator +from core.app.validators.suggested_questions import SuggestedQuestionsValidator +from core.app.validators.text_to_speech import TextToSpeechValidator +from core.app.validators.user_input_form import UserInputFormValidator +from models.model import AppMode + + +OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] + + +class AgentChatAppConfigValidator: + @classmethod + def config_validate(cls, tenant_id: str, config: dict) -> dict: + """ + Validate for agent chat app model config + + :param tenant_id: tenant id + :param config: app model config args + """ + app_mode = AppMode.AGENT_CHAT + + related_config_keys = [] + + # model + config, current_related_config_keys = ModelValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # user_input_form + config, current_related_config_keys = UserInputFormValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # external data tools validation + config, current_related_config_keys = ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # file upload validation + config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # prompt + config, current_related_config_keys = PromptValidator.validate_and_set_defaults(app_mode, config) + related_config_keys.extend(current_related_config_keys) + + # agent_mode + config, current_related_config_keys = cls.validate_agent_mode_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # opening_statement + config, current_related_config_keys = OpeningStatementValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # suggested_questions_after_answer + config, current_related_config_keys = SuggestedQuestionsValidator.validate_and_set_defaults(config) + 
related_config_keys.extend(current_related_config_keys) + + # speech_to_text + config, current_related_config_keys = SpeechToTextValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # return retriever resource + config, current_related_config_keys = RetrieverResourceValidator.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config + + @classmethod + def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: + """ + Validate agent_mode and set defaults for agent feature + + :param tenant_id: tenant ID + :param config: app model config args + """ + if not config.get("agent_mode"): + config["agent_mode"] = { + "enabled": False, + "tools": [] + } + + if not isinstance(config["agent_mode"], dict): + raise ValueError("agent_mode must be of object type") + + if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]: + config["agent_mode"]["enabled"] = False + + if not isinstance(config["agent_mode"]["enabled"], bool): + raise ValueError("enabled in agent_mode must be of boolean type") + + if not config["agent_mode"].get("strategy"): + config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value + + if config["agent_mode"]["strategy"] not in [member.value for member in + list(PlanningStrategy.__members__.values())]: + raise ValueError("strategy in agent_mode must be in the specified strategy list") + + if not config["agent_mode"].get("tools"): + config["agent_mode"]["tools"] = [] + + if not isinstance(config["agent_mode"]["tools"], list): + raise ValueError("tools in agent_mode must be a list of objects") + + for tool in config["agent_mode"]["tools"]: + key = list(tool.keys())[0] + if key in OLD_TOOLS: + # old style, use tool name as key + tool_item = tool[key] + + if "enabled" not in tool_item or not tool_item["enabled"]: + tool_item["enabled"] = False + + if not isinstance(tool_item["enabled"], bool): + raise ValueError("enabled in agent_mode.tools must be of boolean type") + + if key == "dataset": + if 'id' not in tool_item: + raise ValueError("id is required in dataset") + + try: + uuid.UUID(tool_item["id"]) + except ValueError: + raise ValueError("id in dataset must be of UUID type") + + if not DatasetValidator.is_dataset_exists(tenant_id, tool_item["id"]): + raise ValueError("Dataset ID does not exist, please check your permission.") + else: + # latest style, use key-value pair + if "enabled" not in tool or not tool["enabled"]: + tool["enabled"] = False + if "provider_type" not in tool: + raise ValueError("provider_type is required in agent_mode.tools") + if "provider_id" not in tool: + raise ValueError("provider_id is required in agent_mode.tools") + if "tool_name" not in tool: + raise ValueError("tool_name is required in agent_mode.tools") + if "tool_parameters" not in tool: + raise ValueError("tool_parameters is required in agent_mode.tools") + + return config, ["agent_mode"] diff 
--git a/api/core/app/app_manager.py b/api/core/app/app_manager.py new file mode 100644 index 0000000000..0819ed864b --- /dev/null +++ b/api/core/app/app_manager.py @@ -0,0 +1,382 @@ +import json +import logging +import threading +import uuid +from collections.abc import Generator +from typing import Any, Optional, Union, cast + +from flask import Flask, current_app +from pydantic import ValidationError + +from core.app.app_orchestration_config_converter import AppOrchestrationConfigConverter +from core.app.agent_chat.app_runner import AgentChatAppRunner +from core.app.chat.app_runner import ChatAppRunner +from core.app.generate_task_pipeline import GenerateTaskPipeline +from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom +from core.entities.application_entities import ( + ApplicationGenerateEntity, + InvokeFrom, +) +from core.file.file_obj import FileObj +from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from extensions.ext_database import db +from models.account import Account +from models.model import App, Conversation, EndUser, Message, MessageFile + +logger = logging.getLogger(__name__) + + +class AppManager: + """ + This class is responsible for managing application + """ + + def generate(self, tenant_id: str, + app_id: str, + app_model_config_id: str, + app_model_config_dict: dict, + app_model_config_override: bool, + user: Union[Account, EndUser], + invoke_from: InvokeFrom, + inputs: dict[str, str], + query: Optional[str] = None, + files: Optional[list[FileObj]] = None, + conversation: Optional[Conversation] = None, + stream: bool = False, + extras: Optional[dict[str, Any]] = None) \ + -> Union[dict, Generator]: + """ + Generate App response. 
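The `generate` signature above fixes the public contract of the new `AppManager`: blocking callers get a plain `dict`, streaming callers get a `Generator` of events, and the body below additionally rejects blocking mode for agent apps. A minimal, self-contained sketch of that dual-return shape — the event payloads are invented for illustration and are not Dify's actual wire format:

```python
from collections.abc import Generator
from typing import Union


def generate(stream: bool) -> Union[dict, Generator]:
    """Return a dict when blocking, a generator of events when streaming."""
    if stream:
        def events() -> Generator[dict, None, None]:
            # stand-ins for the queue-fed events the real pipeline emits
            yield {"event": "message", "answer": "Hello"}
            yield {"event": "message_end"}
        return events()
    return {"answer": "Hello"}


print(generate(stream=False))       # {'answer': 'Hello'}
print(list(generate(stream=True)))  # [{'event': 'message', ...}, {'event': 'message_end'}]
```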
+ + :param tenant_id: workspace ID + :param app_id: app ID + :param app_model_config_id: app model config id + :param app_model_config_dict: app model config dict + :param app_model_config_override: app model config override + :param user: account or end user + :param invoke_from: invoke from source + :param inputs: inputs + :param query: query + :param files: file obj list + :param conversation: conversation + :param stream: is stream + :param extras: extras + """ + # init task id + task_id = str(uuid.uuid4()) + + # init application generate entity + application_generate_entity = ApplicationGenerateEntity( + task_id=task_id, + tenant_id=tenant_id, + app_id=app_id, + app_model_config_id=app_model_config_id, + app_model_config_dict=app_model_config_dict, + app_orchestration_config_entity=AppOrchestrationConfigConverter.convert_from_app_model_config_dict( + tenant_id=tenant_id, + app_model_config_dict=app_model_config_dict + ), + app_model_config_override=app_model_config_override, + conversation_id=conversation.id if conversation else None, + inputs=conversation.inputs if conversation else inputs, + query=query.replace('\x00', '') if query else None, + files=files if files else [], + user_id=user.id, + stream=stream, + invoke_from=invoke_from, + extras=extras + ) + + if not stream and application_generate_entity.app_orchestration_config_entity.agent: + raise ValueError("Agent app is not supported in blocking mode.") + + # init generate records + ( + conversation, + message + ) = self._init_generate_records(application_generate_entity) + + # init queue manager + queue_manager = AppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id + ) + + # new thread + worker_thread = threading.Thread(target=self._generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'application_generate_entity': application_generate_entity, + 'queue_manager': queue_manager, + 'conversation_id': conversation.id, + 'message_id': message.id, + }) + + worker_thread.start() + + # return response or stream generator + return self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + stream=stream + ) + + def _generate_worker(self, flask_app: Flask, + application_generate_entity: ApplicationGenerateEntity, + queue_manager: AppQueueManager, + conversation_id: str, + message_id: str) -> None: + """ + Generate worker in a new thread. 
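`_generate_worker` above and `_handle_response` just below form a producer–consumer pair: the worker thread runs the app runner inside a Flask app context and publishes events through the `AppQueueManager` (which wraps a `queue.Queue`, as the rename hunk further down shows), while the request thread drains those events via `GenerateTaskPipeline`. A stripped-down, runnable sketch of the same pattern using only the standard library; the names and chunk payloads are illustrative, not Dify APIs:

```python
import queue
import threading


def _worker(q: queue.Queue) -> None:
    """Producer: stands in for the app runner publishing LLM chunks."""
    try:
        for chunk in ("Hel", "lo, ", "world"):
            q.put(chunk)
    finally:
        q.put(None)  # sentinel: generation finished (or errored)


def handle_response() -> str:
    """Consumer: stands in for the generate task pipeline."""
    q: queue.Queue = queue.Queue()
    threading.Thread(target=_worker, args=(q,), daemon=True).start()
    parts = []
    while (item := q.get()) is not None:  # blocks until the next event
        parts.append(item)
    return "".join(parts)


print(handle_response())  # -> Hello, world
```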
+ :param flask_app: Flask app + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param conversation_id: conversation ID + :param message_id: message ID + :return: + """ + with flask_app.app_context(): + try: + # get conversation and message + conversation = self._get_conversation(conversation_id) + message = self._get_message(message_id) + + if application_generate_entity.app_orchestration_config_entity.agent: + # agent app + runner = AgentChatAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + else: + # basic app + runner = ChatAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + except ConversationTaskStoppedException: + pass + except InvokeAuthorizationError: + queue_manager.publish_error( + InvokeAuthorizationError('Incorrect API key provided'), + PublishFrom.APPLICATION_MANAGER + ) + except ValidationError as e: + logger.exception("Validation Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except (ValueError, InvokeError) as e: + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except Exception as e: + logger.exception("Unknown Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + finally: + db.session.remove() + + def _handle_response(self, application_generate_entity: ApplicationGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + stream: bool = False) -> Union[dict, Generator]: + """ + Handle response. + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param conversation: conversation + :param message: message + :param stream: is stream + :return: + """ + # init generate task pipeline + generate_task_pipeline = GenerateTaskPipeline( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + + try: + return generate_task_pipeline.process(stream=stream) + except ValueError as e: + if e.args[0] == "I/O operation on closed file.": # ignore this error + raise ConversationTaskStoppedException() + else: + logger.exception(e) + raise e + finally: + db.session.remove() + + def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \ + -> tuple[Conversation, Message]: + """ + Initialize generate records + :param application_generate_entity: application generate entity + :return: + """ + app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity + + model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + model_schema = model_type_instance.get_model_schema( + model=app_orchestration_config_entity.model_config.model, + credentials=app_orchestration_config_entity.model_config.credentials + ) + + app_record = (db.session.query(App) + .filter(App.id == application_generate_entity.app_id).first()) + + app_mode = app_record.mode + + # get from source + end_user_id = None + account_id = None + if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + from_source = 'api' + end_user_id = 
application_generate_entity.user_id + else: + from_source = 'console' + account_id = application_generate_entity.user_id + + override_model_configs = None + if application_generate_entity.app_model_config_override: + override_model_configs = application_generate_entity.app_model_config_dict + + introduction = '' + if app_mode == 'chat': + # get conversation introduction + introduction = self._get_conversation_introduction(application_generate_entity) + + if not application_generate_entity.conversation_id: + conversation = Conversation( + app_id=app_record.id, + app_model_config_id=application_generate_entity.app_model_config_id, + model_provider=app_orchestration_config_entity.model_config.provider, + model_id=app_orchestration_config_entity.model_config.model, + override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, + mode=app_mode, + name='New conversation', + inputs=application_generate_entity.inputs, + introduction=introduction, + system_instruction="", + system_instruction_tokens=0, + status='normal', + from_source=from_source, + from_end_user_id=end_user_id, + from_account_id=account_id, + ) + + db.session.add(conversation) + db.session.commit() + else: + conversation = ( + db.session.query(Conversation) + .filter( + Conversation.id == application_generate_entity.conversation_id, + Conversation.app_id == app_record.id + ).first() + ) + + currency = model_schema.pricing.currency if model_schema.pricing else 'USD' + + message = Message( + app_id=app_record.id, + model_provider=app_orchestration_config_entity.model_config.provider, + model_id=app_orchestration_config_entity.model_config.model, + override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, + conversation_id=conversation.id, + inputs=application_generate_entity.inputs, + query=application_generate_entity.query or "", + message="", + message_tokens=0, + message_unit_price=0, + message_price_unit=0, + answer="", + answer_tokens=0, + answer_unit_price=0, + answer_price_unit=0, + provider_response_latency=0, + total_price=0, + currency=currency, + from_source=from_source, + from_end_user_id=end_user_id, + from_account_id=account_id, + agent_based=app_orchestration_config_entity.agent is not None + ) + + db.session.add(message) + db.session.commit() + + for file in application_generate_entity.files: + message_file = MessageFile( + message_id=message.id, + type=file.type.value, + transfer_method=file.transfer_method.value, + belongs_to='user', + url=file.url, + upload_file_id=file.upload_file_id, + created_by_role=('account' if account_id else 'end_user'), + created_by=account_id or end_user_id, + ) + db.session.add(message_file) + db.session.commit() + + return conversation, message + + def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str: + """ + Get conversation introduction + :param application_generate_entity: application generate entity + :return: conversation introduction + """ + app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity + introduction = app_orchestration_config_entity.opening_statement + + if introduction: + try: + inputs = application_generate_entity.inputs + prompt_template = PromptTemplateParser(template=introduction) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + introduction = prompt_template.format(prompt_inputs) + except KeyError: + pass + + return introduction + + def _get_conversation(self, 
conversation_id: str) -> Conversation: + """ + Get conversation by conversation id + :param conversation_id: conversation id + :return: conversation + """ + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == conversation_id) + .first() + ) + + return conversation + + def _get_message(self, message_id: str) -> Message: + """ + Get message by message id + :param message_id: message id + :return: message + """ + message = ( + db.session.query(Message) + .filter(Message.id == message_id) + .first() + ) + + return message diff --git a/api/core/application_manager.py b/api/core/app/app_orchestration_config_converter.py similarity index 52% rename from api/core/application_manager.py rename to api/core/app/app_orchestration_config_converter.py index ea0c85427d..ddf49949a3 100644 --- a/api/core/application_manager.py +++ b/api/core/app/app_orchestration_config_converter.py @@ -1,241 +1,21 @@ -import json -import logging -import threading -import uuid -from collections.abc import Generator -from typing import Any, Optional, Union, cast +from typing import cast -from flask import Flask, current_app -from pydantic import ValidationError - -from core.app_runner.assistant_app_runner import AssistantApplicationRunner -from core.app_runner.basic_app_runner import BasicApplicationRunner -from core.app_runner.generate_task_pipeline import GenerateTaskPipeline -from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom -from core.entities.application_entities import ( - AdvancedChatPromptTemplateEntity, - AdvancedCompletionPromptTemplateEntity, - AgentEntity, - AgentPromptEntity, - AgentToolEntity, - ApplicationGenerateEntity, - AppOrchestrationConfigEntity, - DatasetEntity, - DatasetRetrieveConfigEntity, - ExternalDataVariableEntity, - FileUploadEntity, - InvokeFrom, - ModelConfigEntity, - PromptTemplateEntity, - SensitiveWordAvoidanceEntity, - TextToSpeechEntity, - VariableEntity, -) +from core.entities.application_entities import AppOrchestrationConfigEntity, SensitiveWordAvoidanceEntity, \ + TextToSpeechEntity, DatasetRetrieveConfigEntity, DatasetEntity, AgentPromptEntity, AgentEntity, AgentToolEntity, \ + ExternalDataVariableEntity, VariableEntity, AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity, \ + AdvancedChatPromptTemplateEntity, ModelConfigEntity, FileUploadEntity from core.entities.model_entities import ModelStatus -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.file.file_obj import FileObj +from core.errors.error import ProviderTokenNotInitError, ModelCurrentlyNotSupportError, QuotaExceededError from core.model_runtime.entities.message_entities import PromptMessageRole from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.prompt.prompt_template import PromptTemplateParser from core.provider_manager import ProviderManager from core.tools.prompt.template import REACT_PROMPT_TEMPLATES -from extensions.ext_database import db -from models.account import Account -from models.model import App, Conversation, EndUser, Message, MessageFile - -logger = logging.getLogger(__name__) -class ApplicationManager: - """ - This class is responsible for managing application - """ - - def generate(self, tenant_id: str, - app_id: str, - app_model_config_id: 
str, - app_model_config_dict: dict, - app_model_config_override: bool, - user: Union[Account, EndUser], - invoke_from: InvokeFrom, - inputs: dict[str, str], - query: Optional[str] = None, - files: Optional[list[FileObj]] = None, - conversation: Optional[Conversation] = None, - stream: bool = False, - extras: Optional[dict[str, Any]] = None) \ - -> Union[dict, Generator]: - """ - Generate App response. - - :param tenant_id: workspace ID - :param app_id: app ID - :param app_model_config_id: app model config id - :param app_model_config_dict: app model config dict - :param app_model_config_override: app model config override - :param user: account or end user - :param invoke_from: invoke from source - :param inputs: inputs - :param query: query - :param files: file obj list - :param conversation: conversation - :param stream: is stream - :param extras: extras - """ - # init task id - task_id = str(uuid.uuid4()) - - # init application generate entity - application_generate_entity = ApplicationGenerateEntity( - task_id=task_id, - tenant_id=tenant_id, - app_id=app_id, - app_model_config_id=app_model_config_id, - app_model_config_dict=app_model_config_dict, - app_orchestration_config_entity=self.convert_from_app_model_config_dict( - tenant_id=tenant_id, - app_model_config_dict=app_model_config_dict - ), - app_model_config_override=app_model_config_override, - conversation_id=conversation.id if conversation else None, - inputs=conversation.inputs if conversation else inputs, - query=query.replace('\x00', '') if query else None, - files=files if files else [], - user_id=user.id, - stream=stream, - invoke_from=invoke_from, - extras=extras - ) - - if not stream and application_generate_entity.app_orchestration_config_entity.agent: - raise ValueError("Agent app is not supported in blocking mode.") - - # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity) - - # init queue manager - queue_manager = ApplicationQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id - ) - - # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'conversation_id': conversation.id, - 'message_id': message.id, - }) - - worker_thread.start() - - # return response or stream generator - return self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - stream=stream - ) - - def _generate_worker(self, flask_app: Flask, - application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, - conversation_id: str, - message_id: str) -> None: - """ - Generate worker in a new thread. 
- :param flask_app: Flask app - :param application_generate_entity: application generate entity - :param queue_manager: queue manager - :param conversation_id: conversation ID - :param message_id: message ID - :return: - """ - with flask_app.app_context(): - try: - # get conversation and message - conversation = self._get_conversation(conversation_id) - message = self._get_message(message_id) - - if application_generate_entity.app_orchestration_config_entity.agent: - # agent app - runner = AssistantApplicationRunner() - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message - ) - else: - # basic app - runner = BasicApplicationRunner() - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message - ) - except ConversationTaskStoppedException: - pass - except InvokeAuthorizationError: - queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER - ) - except ValidationError as e: - logger.exception("Validation Error when generating") - queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - except (ValueError, InvokeError) as e: - queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - except Exception as e: - logger.exception("Unknown Error when generating") - queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - finally: - db.session.close() - - def _handle_response(self, application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, - conversation: Conversation, - message: Message, - stream: bool = False) -> Union[dict, Generator]: - """ - Handle response. 
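The removals here retire the old `ApplicationManager` generation plumbing (now living in `core.app.app_manager.AppManager`), and the hunk just below turns `convert_from_app_model_config_dict` into a classmethod on the new `AppOrchestrationConfigConverter`. For a hypothetical downstream caller inside a Dify checkout, the migration would look roughly like this — a sketch only, with placeholder values that would not pass real config validation:

```python
# Old call shape (removed by this patch):
#   from core.application_manager import ApplicationManager
#   entity = ApplicationManager().convert_from_app_model_config_dict(
#       tenant_id=tenant_id, app_model_config_dict=config_dict)

# New call shape (introduced by this patch); note the classmethod access:
from core.app.app_orchestration_config_converter import AppOrchestrationConfigConverter

tenant_id = "..."   # placeholder: a real workspace id
config_dict = {}    # placeholder: a real app model config dict

entity = AppOrchestrationConfigConverter.convert_from_app_model_config_dict(
    tenant_id=tenant_id,
    app_model_config_dict=config_dict,
)
```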
- :param application_generate_entity: application generate entity - :param queue_manager: queue manager - :param conversation: conversation - :param message: message - :param stream: is stream - :return: - """ - # init generate task pipeline - generate_task_pipeline = GenerateTaskPipeline( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message - ) - - try: - return generate_task_pipeline.process(stream=stream) - except ValueError as e: - if e.args[0] == "I/O operation on closed file.": # ignore this error - raise ConversationTaskStoppedException() - else: - logger.exception(e) - raise e - - def convert_from_app_model_config_dict(self, tenant_id: str, +class AppOrchestrationConfigConverter: + @classmethod + def convert_from_app_model_config_dict(cls, tenant_id: str, app_model_config_dict: dict, skip_check: bool = False) \ -> AppOrchestrationConfigEntity: @@ -394,7 +174,7 @@ class ApplicationManager: ) properties['variables'] = [] - + # variables and external_data_tools for variable in copy_app_model_config_dict.get('user_input_form', []): typ = list(variable.keys())[0] @@ -444,7 +224,7 @@ class ApplicationManager: show_retrieve_source = True properties['show_retrieve_source'] = show_retrieve_source - + dataset_ids = [] if 'datasets' in copy_app_model_config_dict.get('dataset_configs', {}): datasets = copy_app_model_config_dict.get('dataset_configs', {}).get('datasets', { @@ -452,26 +232,23 @@ class ApplicationManager: 'datasets': [] }) - for dataset in datasets.get('datasets', []): keys = list(dataset.keys()) if len(keys) == 0 or keys[0] != 'dataset': continue dataset = dataset['dataset'] - + if 'enabled' not in dataset or not dataset['enabled']: continue - + dataset_id = dataset.get('id', None) if dataset_id: dataset_ids.append(dataset_id) - else: - datasets = {'strategy': 'router', 'datasets': []} if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \ and 'enabled' in copy_app_model_config_dict['agent_mode'] \ and copy_app_model_config_dict['agent_mode']['enabled']: - + agent_dict = copy_app_model_config_dict.get('agent_mode', {}) agent_strategy = agent_dict.get('strategy', 'cot') @@ -515,7 +292,7 @@ class ApplicationManager: dataset_id = tool_item['id'] dataset_ids.append(dataset_id) - + if 'strategy' in copy_app_model_config_dict['agent_mode'] and \ copy_app_model_config_dict['agent_mode']['strategy'] not in ['react_router', 'router']: agent_prompt = agent_dict.get('prompt', None) or {} @@ -523,13 +300,18 @@ class ApplicationManager: model_mode = copy_app_model_config_dict.get('model', {}).get('mode', 'completion') if model_mode == 'completion': agent_prompt_entity = AgentPromptEntity( - first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['completion']['prompt']), - next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['completion']['agent_scratchpad']), + first_prompt=agent_prompt.get('first_prompt', + REACT_PROMPT_TEMPLATES['english']['completion']['prompt']), + next_iteration=agent_prompt.get('next_iteration', + REACT_PROMPT_TEMPLATES['english']['completion'][ + 'agent_scratchpad']), ) else: agent_prompt_entity = AgentPromptEntity( - first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['chat']['prompt']), - next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']), + first_prompt=agent_prompt.get('first_prompt', + 
REACT_PROMPT_TEMPLATES['english']['chat']['prompt']), + next_iteration=agent_prompt.get('next_iteration', + REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']), ) properties['agent'] = AgentEntity( @@ -551,7 +333,7 @@ class ApplicationManager: dataset_ids=dataset_ids, retrieve_config=DatasetRetrieveConfigEntity( query_variable=query_variable, - retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( dataset_configs['retrieval_model'] ) ) @@ -624,169 +406,3 @@ class ApplicationManager: ) return AppOrchestrationConfigEntity(**properties) - - def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \ - -> tuple[Conversation, Message]: - """ - Initialize generate records - :param application_generate_entity: application generate entity - :return: - """ - app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity - - model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - model_schema = model_type_instance.get_model_schema( - model=app_orchestration_config_entity.model_config.model, - credentials=app_orchestration_config_entity.model_config.credentials - ) - - app_record = (db.session.query(App) - .filter(App.id == application_generate_entity.app_id).first()) - - app_mode = app_record.mode - - # get from source - end_user_id = None - account_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - from_source = 'api' - end_user_id = application_generate_entity.user_id - else: - from_source = 'console' - account_id = application_generate_entity.user_id - - override_model_configs = None - if application_generate_entity.app_model_config_override: - override_model_configs = application_generate_entity.app_model_config_dict - - introduction = '' - if app_mode == 'chat': - # get conversation introduction - introduction = self._get_conversation_introduction(application_generate_entity) - - if not application_generate_entity.conversation_id: - conversation = Conversation( - app_id=app_record.id, - app_model_config_id=application_generate_entity.app_model_config_id, - model_provider=app_orchestration_config_entity.model_config.provider, - model_id=app_orchestration_config_entity.model_config.model, - override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, - mode=app_mode, - name='New conversation', - inputs=application_generate_entity.inputs, - introduction=introduction, - system_instruction="", - system_instruction_tokens=0, - status='normal', - from_source=from_source, - from_end_user_id=end_user_id, - from_account_id=account_id, - ) - - db.session.add(conversation) - db.session.commit() - db.session.refresh(conversation) - else: - conversation = ( - db.session.query(Conversation) - .filter( - Conversation.id == application_generate_entity.conversation_id, - Conversation.app_id == app_record.id - ).first() - ) - - currency = model_schema.pricing.currency if model_schema.pricing else 'USD' - - message = Message( - app_id=app_record.id, - model_provider=app_orchestration_config_entity.model_config.provider, - model_id=app_orchestration_config_entity.model_config.model, - override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, - conversation_id=conversation.id, - 
inputs=application_generate_entity.inputs, - query=application_generate_entity.query or "", - message="", - message_tokens=0, - message_unit_price=0, - message_price_unit=0, - answer="", - answer_tokens=0, - answer_unit_price=0, - answer_price_unit=0, - provider_response_latency=0, - total_price=0, - currency=currency, - from_source=from_source, - from_end_user_id=end_user_id, - from_account_id=account_id, - agent_based=app_orchestration_config_entity.agent is not None - ) - - db.session.add(message) - db.session.commit() - db.session.refresh(message) - - for file in application_generate_entity.files: - message_file = MessageFile( - message_id=message.id, - type=file.type.value, - transfer_method=file.transfer_method.value, - belongs_to='user', - url=file.url, - upload_file_id=file.upload_file_id, - created_by_role=('account' if account_id else 'end_user'), - created_by=account_id or end_user_id, - ) - db.session.add(message_file) - db.session.commit() - - return conversation, message - - def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str: - """ - Get conversation introduction - :param application_generate_entity: application generate entity - :return: conversation introduction - """ - app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity - introduction = app_orchestration_config_entity.opening_statement - - if introduction: - try: - inputs = application_generate_entity.inputs - prompt_template = PromptTemplateParser(template=introduction) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - introduction = prompt_template.format(prompt_inputs) - except KeyError: - pass - - return introduction - - def _get_conversation(self, conversation_id: str) -> Conversation: - """ - Get conversation by conversation id - :param conversation_id: conversation id - :return: conversation - """ - conversation = ( - db.session.query(Conversation) - .filter(Conversation.id == conversation_id) - .first() - ) - - return conversation - - def _get_message(self, message_id: str) -> Message: - """ - Get message by message id - :param message_id: message id - :return: message - """ - message = ( - db.session.query(Message) - .filter(Message.id == message_id) - .first() - ) - - return message diff --git a/api/core/application_queue_manager.py b/api/core/app/app_queue_manager.py similarity index 97% rename from api/core/application_queue_manager.py rename to api/core/app/app_queue_manager.py index 9590a1e726..c09cae3245 100644 --- a/api/core/application_queue_manager.py +++ b/api/core/app/app_queue_manager.py @@ -32,7 +32,7 @@ class PublishFrom(Enum): TASK_PIPELINE = 2 -class ApplicationQueueManager: +class AppQueueManager: def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, @@ -50,7 +50,7 @@ class ApplicationQueueManager: self._message_id = str(message_id) user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' - redis_client.setex(ApplicationQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}") + redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}") q = queue.Queue() @@ -239,7 +239,7 @@ class ApplicationQueueManager: Check if task is stopped :return: """ - stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id) + stopped_cache_key = 
AppQueueManager._generate_stopped_cache_key(self._task_id) result = redis_client.get(stopped_cache_key) if result is not None: return True diff --git a/api/core/app_runner/app_runner.py b/api/core/app/base_app_runner.py similarity index 94% rename from api/core/app_runner/app_runner.py rename to api/core/app/base_app_runner.py index 95f2f568dc..788e3f91a3 100644 --- a/api/core/app_runner/app_runner.py +++ b/api/core/app/base_app_runner.py @@ -2,7 +2,7 @@ import time from collections.abc import Generator from typing import Optional, Union, cast -from core.application_queue_manager import ApplicationQueueManager, PublishFrom +from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.entities.application_entities import ( ApplicationGenerateEntity, AppOrchestrationConfigEntity, @@ -11,10 +11,10 @@ from core.entities.application_entities import ( ModelConfigEntity, PromptTemplateEntity, ) -from core.features.annotation_reply import AnnotationReplyFeature -from core.features.external_data_fetch import ExternalDataFetchFeature -from core.features.hosting_moderation import HostingModerationFeature -from core.features.moderation import ModerationFeature +from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature +from core.external_data_tool.external_data_fetch import ExternalDataFetch +from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature +from core.moderation.input_moderation import InputModeration from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -169,7 +169,7 @@ class AppRunner: return prompt_messages, stop - def direct_output(self, queue_manager: ApplicationQueueManager, + def direct_output(self, queue_manager: AppQueueManager, app_orchestration_config: AppOrchestrationConfigEntity, prompt_messages: list, text: str, @@ -210,7 +210,7 @@ class AppRunner: ) def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator], - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, stream: bool, agent: bool = False) -> None: """ @@ -234,7 +234,7 @@ class AppRunner: ) def _handle_invoke_result_direct(self, invoke_result: LLMResult, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, agent: bool) -> None: """ Handle invoke result direct @@ -248,7 +248,7 @@ class AppRunner: ) def _handle_invoke_result_stream(self, invoke_result: Generator, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, agent: bool) -> None: """ Handle invoke result @@ -306,7 +306,7 @@ class AppRunner: :param query: query :return: """ - moderation_feature = ModerationFeature() + moderation_feature = InputModeration() return moderation_feature.check( app_id=app_id, tenant_id=tenant_id, @@ -316,7 +316,7 @@ class AppRunner: ) def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, prompt_messages: list[PromptMessage]) -> bool: """ Check hosting moderation @@ -358,7 +358,7 @@ class AppRunner: :param query: the query :return: the filled inputs """ - external_data_fetch_feature = ExternalDataFetchFeature() + external_data_fetch_feature = ExternalDataFetch() return external_data_fetch_feature.fetch( tenant_id=tenant_id, app_id=app_id, @@ -388,4 +388,4 @@ class AppRunner: query=query, user_id=user_id, 
invoke_from=invoke_from - ) \ No newline at end of file + ) diff --git a/api/core/features/__init__.py b/api/core/app/chat/__init__.py similarity index 100% rename from api/core/features/__init__.py rename to api/core/app/chat/__init__.py diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app/chat/app_runner.py similarity index 95% rename from api/core/app_runner/basic_app_runner.py rename to api/core/app/chat/app_runner.py index 0e0fe6e3bf..a1613e37a2 100644 --- a/api/core/app_runner/basic_app_runner.py +++ b/api/core/app/chat/app_runner.py @@ -1,8 +1,8 @@ import logging from typing import Optional -from core.app_runner.app_runner import AppRunner -from core.application_queue_manager import ApplicationQueueManager, PublishFrom +from core.app.base_app_runner import AppRunner +from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.application_entities import ( ApplicationGenerateEntity, @@ -10,7 +10,7 @@ from core.entities.application_entities import ( InvokeFrom, ModelConfigEntity, ) -from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationException @@ -20,13 +20,13 @@ from models.model import App, AppMode, Conversation, Message logger = logging.getLogger(__name__) -class BasicApplicationRunner(AppRunner): +class ChatAppRunner(AppRunner): """ - Basic Application Runner + Chat Application Runner """ def run(self, application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, conversation: Conversation, message: Message) -> None: """ @@ -215,7 +215,7 @@ class BasicApplicationRunner(AppRunner): def retrieve_dataset_context(self, tenant_id: str, app_record: App, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, model_config: ModelConfigEntity, dataset_config: DatasetEntity, show_retrieve_source: bool, @@ -254,7 +254,7 @@ class BasicApplicationRunner(AppRunner): and dataset_config.retrieve_config.query_variable): query = inputs.get(dataset_config.retrieve_config.query_variable, "") - dataset_retrieval = DatasetRetrievalFeature() + dataset_retrieval = DatasetRetrieval() return dataset_retrieval.retrieve( tenant_id=tenant_id, model_config=model_config, diff --git a/api/core/apps/app_config_validators/chat_app.py b/api/core/app/chat/config_validator.py similarity index 75% rename from api/core/apps/app_config_validators/chat_app.py rename to api/core/app/chat/config_validator.py index 83c792e610..adb8408e28 100644 --- a/api/core/apps/app_config_validators/chat_app.py +++ b/api/core/app/chat/config_validator.py @@ -1,15 +1,15 @@ -from core.apps.config_validators.dataset import DatasetValidator -from core.apps.config_validators.external_data_tools import ExternalDataToolsValidator -from core.apps.config_validators.file_upload import FileUploadValidator -from core.apps.config_validators.model import ModelValidator -from core.apps.config_validators.moderation import ModerationValidator -from core.apps.config_validators.opening_statement import OpeningStatementValidator -from core.apps.config_validators.prompt import PromptValidator -from core.apps.config_validators.retriever_resource import RetrieverResourceValidator -from 
core.apps.config_validators.speech_to_text import SpeechToTextValidator -from core.apps.config_validators.suggested_questions import SuggestedQuestionsValidator -from core.apps.config_validators.text_to_speech import TextToSpeechValidator -from core.apps.config_validators.user_input_form import UserInputFormValidator +from core.app.validators.dataset_retrieval import DatasetValidator +from core.app.validators.external_data_fetch import ExternalDataFetchValidator +from core.app.validators.file_upload import FileUploadValidator +from core.app.validators.model_validator import ModelValidator +from core.app.validators.moderation import ModerationValidator +from core.app.validators.opening_statement import OpeningStatementValidator +from core.app.validators.prompt import PromptValidator +from core.app.validators.retriever_resource import RetrieverResourceValidator +from core.app.validators.speech_to_text import SpeechToTextValidator +from core.app.validators.suggested_questions import SuggestedQuestionsValidator +from core.app.validators.text_to_speech import TextToSpeechValidator +from core.app.validators.user_input_form import UserInputFormValidator from models.model import AppMode @@ -35,7 +35,7 @@ class ChatAppConfigValidator: related_config_keys.extend(current_related_config_keys) # external data tools validation - config, current_related_config_keys = ExternalDataToolsValidator.validate_and_set_defaults(tenant_id, config) + config, current_related_config_keys = ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, config) related_config_keys.extend(current_related_config_keys) # file upload validation diff --git a/api/core/features/dataset_retrieval/__init__.py b/api/core/app/completion/__init__.py similarity index 100% rename from api/core/features/dataset_retrieval/__init__.py rename to api/core/app/completion/__init__.py diff --git a/api/core/app/completion/app_runner.py b/api/core/app/completion/app_runner.py new file mode 100644 index 0000000000..34c6a5156f --- /dev/null +++ b/api/core/app/completion/app_runner.py @@ -0,0 +1,266 @@ +import logging +from typing import Optional + +from core.app.base_app_runner import AppRunner +from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.entities.application_entities import ( + ApplicationGenerateEntity, + DatasetEntity, + InvokeFrom, + ModelConfigEntity, +) +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance +from core.moderation.base import ModerationException +from extensions.ext_database import db +from models.model import App, AppMode, Conversation, Message + +logger = logging.getLogger(__name__) + + +class CompletionAppRunner(AppRunner): + """ + Completion Application Runner + """ + + def run(self, application_generate_entity: ApplicationGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message) -> None: + """ + Run application + :param application_generate_entity: application generate entity + :param queue_manager: application queue manager + :param conversation: conversation + :param message: message + :return: + """ + app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first() + if not app_record: + raise ValueError("App not found") + + app_orchestration_config = 
application_generate_entity.app_orchestration_config_entity + + inputs = application_generate_entity.inputs + query = application_generate_entity.query + files = application_generate_entity.files + + # Pre-calculate the number of tokens of the prompt messages, + # and return the rest number of tokens by model context token size limit and max token size limit. + # If the rest number of tokens is not enough, raise exception. + # Include: prompt template, inputs, query(optional), files(optional) + # Not Include: memory, external data, dataset context + self.get_pre_calculate_rest_tokens( + app_record=app_record, + model_config=app_orchestration_config.model_config, + prompt_template_entity=app_orchestration_config.prompt_template, + inputs=inputs, + files=files, + query=query + ) + + memory = None + if application_generate_entity.conversation_id: + # get memory of conversation (read-only) + model_instance = ModelInstance( + provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, + model=app_orchestration_config.model_config.model + ) + + memory = TokenBufferMemory( + conversation=conversation, + model_instance=model_instance + ) + + # organize all inputs and template to prompt messages + # Include: prompt template, inputs, query(optional), files(optional) + # memory(optional) + prompt_messages, stop = self.organize_prompt_messages( + app_record=app_record, + model_config=app_orchestration_config.model_config, + prompt_template_entity=app_orchestration_config.prompt_template, + inputs=inputs, + files=files, + query=query, + memory=memory + ) + + # moderation + try: + # process sensitive_word_avoidance + _, inputs, query = self.moderation_for_inputs( + app_id=app_record.id, + tenant_id=application_generate_entity.tenant_id, + app_orchestration_config_entity=app_orchestration_config, + inputs=inputs, + query=query, + ) + except ModerationException as e: + self.direct_output( + queue_manager=queue_manager, + app_orchestration_config=app_orchestration_config, + prompt_messages=prompt_messages, + text=str(e), + stream=application_generate_entity.stream + ) + return + + if query: + # annotation reply + annotation_reply = self.query_app_annotations_to_reply( + app_record=app_record, + message=message, + query=query, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from + ) + + if annotation_reply: + queue_manager.publish_annotation_reply( + message_annotation_id=annotation_reply.id, + pub_from=PublishFrom.APPLICATION_MANAGER + ) + self.direct_output( + queue_manager=queue_manager, + app_orchestration_config=app_orchestration_config, + prompt_messages=prompt_messages, + text=annotation_reply.content, + stream=application_generate_entity.stream + ) + return + + # fill in variable inputs from external data tools if exists + external_data_tools = app_orchestration_config.external_data_variables + if external_data_tools: + inputs = self.fill_in_inputs_from_external_data_tools( + tenant_id=app_record.tenant_id, + app_id=app_record.id, + external_data_tools=external_data_tools, + inputs=inputs, + query=query + ) + + # get context from datasets + context = None + if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids: + context = self.retrieve_dataset_context( + tenant_id=app_record.tenant_id, + app_record=app_record, + queue_manager=queue_manager, + model_config=app_orchestration_config.model_config, + show_retrieve_source=app_orchestration_config.show_retrieve_source, +
dataset_config=app_orchestration_config.dataset, + message=message, + inputs=inputs, + query=query, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + memory=memory + ) + + # reorganize all inputs and template to prompt messages + # Include: prompt template, inputs, query(optional), files(optional) + # memory(optional), external data, dataset context(optional) + prompt_messages, stop = self.organize_prompt_messages( + app_record=app_record, + model_config=app_orchestration_config.model_config, + prompt_template_entity=app_orchestration_config.prompt_template, + inputs=inputs, + files=files, + query=query, + context=context, + memory=memory + ) + + # check hosting moderation + hosting_moderation_result = self.check_hosting_moderation( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + prompt_messages=prompt_messages + ) + + if hosting_moderation_result: + return + + # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit + self.recale_llm_max_tokens( + model_config=app_orchestration_config.model_config, + prompt_messages=prompt_messages + ) + + # Invoke model + model_instance = ModelInstance( + provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, + model=app_orchestration_config.model_config.model + ) + + invoke_result = model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=app_orchestration_config.model_config.parameters, + stop=stop, + stream=application_generate_entity.stream, + user=application_generate_entity.user_id, + ) + + # handle invoke result + self._handle_invoke_result( + invoke_result=invoke_result, + queue_manager=queue_manager, + stream=application_generate_entity.stream + ) + + def retrieve_dataset_context(self, tenant_id: str, + app_record: App, + queue_manager: AppQueueManager, + model_config: ModelConfigEntity, + dataset_config: DatasetEntity, + show_retrieve_source: bool, + message: Message, + inputs: dict, + query: str, + user_id: str, + invoke_from: InvokeFrom, + memory: Optional[TokenBufferMemory] = None) -> Optional[str]: + """ + Retrieve dataset context + :param tenant_id: tenant id + :param app_record: app record + :param queue_manager: queue manager + :param model_config: model config + :param dataset_config: dataset config + :param show_retrieve_source: show retrieve source + :param message: message + :param inputs: inputs + :param query: query + :param user_id: user id + :param invoke_from: invoke from + :param memory: memory + :return: + """ + hit_callback = DatasetIndexToolCallbackHandler( + queue_manager, + app_record.id, + message.id, + user_id, + invoke_from + ) + + # TODO + if (app_record.mode == AppMode.COMPLETION.value and dataset_config + and dataset_config.retrieve_config.query_variable): + query = inputs.get(dataset_config.retrieve_config.query_variable, "") + + dataset_retrieval = DatasetRetrieval() + return dataset_retrieval.retrieve( + tenant_id=tenant_id, + model_config=model_config, + config=dataset_config, + query=query, + invoke_from=invoke_from, + show_retrieve_source=show_retrieve_source, + hit_callback=hit_callback, + memory=memory + ) + \ No newline at end of file diff --git a/api/core/apps/app_config_validators/completion_app.py b/api/core/app/completion/config_validator.py similarity index 76% rename from api/core/apps/app_config_validators/completion_app.py rename to api/core/app/completion/config_validator.py index 00371f8d05..7cc35efd64 100644 --- 
a/api/core/apps/app_config_validators/completion_app.py +++ b/api/core/app/completion/config_validator.py @@ -1,12 +1,12 @@ -from core.apps.config_validators.dataset import DatasetValidator -from core.apps.config_validators.external_data_tools import ExternalDataToolsValidator -from core.apps.config_validators.file_upload import FileUploadValidator -from core.apps.config_validators.model import ModelValidator -from core.apps.config_validators.moderation import ModerationValidator -from core.apps.config_validators.more_like_this import MoreLikeThisValidator -from core.apps.config_validators.prompt import PromptValidator -from core.apps.config_validators.text_to_speech import TextToSpeechValidator -from core.apps.config_validators.user_input_form import UserInputFormValidator +from core.app.validators.dataset_retrieval import DatasetValidator +from core.app.validators.external_data_fetch import ExternalDataFetchValidator +from core.app.validators.file_upload import FileUploadValidator +from core.app.validators.model_validator import ModelValidator +from core.app.validators.moderation import ModerationValidator +from core.app.validators.more_like_this import MoreLikeThisValidator +from core.app.validators.prompt import PromptValidator +from core.app.validators.text_to_speech import TextToSpeechValidator +from core.app.validators.user_input_form import UserInputFormValidator from models.model import AppMode @@ -32,7 +32,7 @@ class CompletionAppConfigValidator: related_config_keys.extend(current_related_config_keys) # external data tools validation - config, current_related_config_keys = ExternalDataToolsValidator.validate_and_set_defaults(tenant_id, config) + config, current_related_config_keys = ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, config) related_config_keys.extend(current_related_config_keys) # file upload validation diff --git a/api/core/features/dataset_retrieval/agent/__init__.py b/api/core/app/features/__init__.py similarity index 100% rename from api/core/features/dataset_retrieval/agent/__init__.py rename to api/core/app/features/__init__.py diff --git a/api/core/features/dataset_retrieval/agent/output_parser/__init__.py b/api/core/app/features/annotation_reply/__init__.py similarity index 100% rename from api/core/features/dataset_retrieval/agent/output_parser/__init__.py rename to api/core/app/features/annotation_reply/__init__.py diff --git a/api/core/features/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py similarity index 100% rename from api/core/features/annotation_reply.py rename to api/core/app/features/annotation_reply/annotation_reply.py diff --git a/api/core/app/features/hosting_moderation/__init__.py b/api/core/app/features/hosting_moderation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/features/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py similarity index 100% rename from api/core/features/hosting_moderation.py rename to api/core/app/features/hosting_moderation/hosting_moderation.py diff --git a/api/core/app_runner/generate_task_pipeline.py b/api/core/app/generate_task_pipeline.py similarity index 98% rename from api/core/app_runner/generate_task_pipeline.py rename to api/core/app/generate_task_pipeline.py index 1cc56483ad..6d52fa7348 100644 --- a/api/core/app_runner/generate_task_pipeline.py +++ b/api/core/app/generate_task_pipeline.py @@ -6,8 +6,8 @@ from typing import Optional, Union, cast from pydantic import BaseModel -from 
core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler -from core.application_queue_manager import ApplicationQueueManager, PublishFrom +from core.moderation.output_moderation import ModerationRule, OutputModeration +from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom from core.entities.queue_entities import ( AnnotationReplyEvent, @@ -35,7 +35,7 @@ from core.model_runtime.entities.message_entities import ( from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder -from core.prompt.prompt_template import PromptTemplateParser +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.tools.tool_file_manager import ToolFileManager from events.message_event import message_was_created from extensions.ext_database import db @@ -59,7 +59,7 @@ class GenerateTaskPipeline: """ def __init__(self, application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, conversation: Conversation, message: Message) -> None: """ @@ -633,7 +633,7 @@ class GenerateTaskPipeline: return prompts - def _init_output_moderation(self) -> Optional[OutputModerationHandler]: + def _init_output_moderation(self) -> Optional[OutputModeration]: """ Init output moderation. :return: @@ -642,7 +642,7 @@ class GenerateTaskPipeline: sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance if sensitive_word_avoidance: - return OutputModerationHandler( + return OutputModeration( tenant_id=self._application_generate_entity.tenant_id, app_id=self._application_generate_entity.app_id, rule=ModerationRule( diff --git a/api/core/app/validators/__init__.py b/api/core/app/validators/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/apps/config_validators/dataset.py b/api/core/app/validators/dataset_retrieval.py similarity index 100% rename from api/core/apps/config_validators/dataset.py rename to api/core/app/validators/dataset_retrieval.py diff --git a/api/core/apps/config_validators/external_data_tools.py b/api/core/app/validators/external_data_fetch.py similarity index 97% rename from api/core/apps/config_validators/external_data_tools.py rename to api/core/app/validators/external_data_fetch.py index 02ecc8d715..5910aa17e7 100644 --- a/api/core/apps/config_validators/external_data_tools.py +++ b/api/core/app/validators/external_data_fetch.py @@ -2,7 +2,7 @@ from core.external_data_tool.factory import ExternalDataToolFactory -class ExternalDataToolsValidator: +class ExternalDataFetchValidator: @classmethod def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: """ diff --git a/api/core/apps/config_validators/file_upload.py b/api/core/app/validators/file_upload.py similarity index 100% rename from api/core/apps/config_validators/file_upload.py rename to api/core/app/validators/file_upload.py diff --git a/api/core/apps/config_validators/model.py b/api/core/app/validators/model_validator.py similarity index 100% rename from api/core/apps/config_validators/model.py rename to api/core/app/validators/model_validator.py diff --git a/api/core/apps/config_validators/moderation.py b/api/core/app/validators/moderation.py similarity index 100% rename from 
api/core/apps/config_validators/moderation.py rename to api/core/app/validators/moderation.py diff --git a/api/core/apps/config_validators/more_like_this.py b/api/core/app/validators/more_like_this.py similarity index 100% rename from api/core/apps/config_validators/more_like_this.py rename to api/core/app/validators/more_like_this.py diff --git a/api/core/apps/config_validators/opening_statement.py b/api/core/app/validators/opening_statement.py similarity index 100% rename from api/core/apps/config_validators/opening_statement.py rename to api/core/app/validators/opening_statement.py diff --git a/api/core/apps/config_validators/prompt.py b/api/core/app/validators/prompt.py similarity index 100% rename from api/core/apps/config_validators/prompt.py rename to api/core/app/validators/prompt.py diff --git a/api/core/apps/config_validators/retriever_resource.py b/api/core/app/validators/retriever_resource.py similarity index 100% rename from api/core/apps/config_validators/retriever_resource.py rename to api/core/app/validators/retriever_resource.py diff --git a/api/core/apps/config_validators/speech_to_text.py b/api/core/app/validators/speech_to_text.py similarity index 100% rename from api/core/apps/config_validators/speech_to_text.py rename to api/core/app/validators/speech_to_text.py diff --git a/api/core/apps/config_validators/suggested_questions.py b/api/core/app/validators/suggested_questions.py similarity index 100% rename from api/core/apps/config_validators/suggested_questions.py rename to api/core/app/validators/suggested_questions.py diff --git a/api/core/apps/config_validators/text_to_speech.py b/api/core/app/validators/text_to_speech.py similarity index 100% rename from api/core/apps/config_validators/text_to_speech.py rename to api/core/app/validators/text_to_speech.py diff --git a/api/core/apps/config_validators/user_input_form.py b/api/core/app/validators/user_input_form.py similarity index 100% rename from api/core/apps/config_validators/user_input_form.py rename to api/core/app/validators/user_input_form.py diff --git a/api/core/app/workflow/__init__.py b/api/core/app/workflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/apps/app_config_validators/workflow_app.py b/api/core/app/workflow/config_validator.py similarity index 83% rename from api/core/apps/app_config_validators/workflow_app.py rename to api/core/app/workflow/config_validator.py index 545d3d79a3..b76eabaeb5 100644 --- a/api/core/apps/app_config_validators/workflow_app.py +++ b/api/core/app/workflow/config_validator.py @@ -1,6 +1,6 @@ -from core.apps.config_validators.file_upload import FileUploadValidator -from core.apps.config_validators.moderation import ModerationValidator -from core.apps.config_validators.text_to_speech import TextToSpeechValidator +from core.app.validators.file_upload import FileUploadValidator +from core.app.validators.moderation import ModerationValidator +from core.app.validators.text_to_speech import TextToSpeechValidator class WorkflowAppConfigValidator: diff --git a/api/core/apps/app_config_validators/agent_chat_app.py b/api/core/apps/app_config_validators/agent_chat_app.py deleted file mode 100644 index d507fae685..0000000000 --- a/api/core/apps/app_config_validators/agent_chat_app.py +++ /dev/null @@ -1,82 +0,0 @@ -from core.apps.config_validators.agent import AgentValidator -from core.apps.config_validators.external_data_tools import ExternalDataToolsValidator -from core.apps.config_validators.file_upload import FileUploadValidator -from 
core.apps.config_validators.model import ModelValidator -from core.apps.config_validators.moderation import ModerationValidator -from core.apps.config_validators.opening_statement import OpeningStatementValidator -from core.apps.config_validators.prompt import PromptValidator -from core.apps.config_validators.retriever_resource import RetrieverResourceValidator -from core.apps.config_validators.speech_to_text import SpeechToTextValidator -from core.apps.config_validators.suggested_questions import SuggestedQuestionsValidator -from core.apps.config_validators.text_to_speech import TextToSpeechValidator -from core.apps.config_validators.user_input_form import UserInputFormValidator -from models.model import AppMode - - -class AgentChatAppConfigValidator: - @classmethod - def config_validate(cls, tenant_id: str, config: dict) -> dict: - """ - Validate for agent chat app model config - - :param tenant_id: tenant id - :param config: app model config args - """ - app_mode = AppMode.AGENT_CHAT - - related_config_keys = [] - - # model - config, current_related_config_keys = ModelValidator.validate_and_set_defaults(tenant_id, config) - related_config_keys.extend(current_related_config_keys) - - # user_input_form - config, current_related_config_keys = UserInputFormValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # external data tools validation - config, current_related_config_keys = ExternalDataToolsValidator.validate_and_set_defaults(tenant_id, config) - related_config_keys.extend(current_related_config_keys) - - # file upload validation - config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # prompt - config, current_related_config_keys = PromptValidator.validate_and_set_defaults(app_mode, config) - related_config_keys.extend(current_related_config_keys) - - # agent_mode - config, current_related_config_keys = AgentValidator.validate_and_set_defaults(tenant_id, config) - related_config_keys.extend(current_related_config_keys) - - # opening_statement - config, current_related_config_keys = OpeningStatementValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # suggested_questions_after_answer - config, current_related_config_keys = SuggestedQuestionsValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # speech_to_text - config, current_related_config_keys = SpeechToTextValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # text_to_speech - config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # return retriever resource - config, current_related_config_keys = RetrieverResourceValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # moderation validation - config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config) - related_config_keys.extend(current_related_config_keys) - - related_config_keys = list(set(related_config_keys)) - - # Filter out extra parameters - filtered_config = {key: config.get(key) for key in related_config_keys} - - return filtered_config diff --git a/api/core/apps/config_validators/agent.py b/api/core/apps/config_validators/agent.py deleted file mode 100644 index 
b445aedbf8..0000000000 --- a/api/core/apps/config_validators/agent.py +++ /dev/null @@ -1,81 +0,0 @@ -import uuid - -from core.apps.config_validators.dataset import DatasetValidator -from core.entities.agent_entities import PlanningStrategy - -OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] - - -class AgentValidator: - @classmethod - def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: - """ - Validate and set defaults for agent feature - - :param tenant_id: tenant ID - :param config: app model config args - """ - if not config.get("agent_mode"): - config["agent_mode"] = { - "enabled": False, - "tools": [] - } - - if not isinstance(config["agent_mode"], dict): - raise ValueError("agent_mode must be of object type") - - if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]: - config["agent_mode"]["enabled"] = False - - if not isinstance(config["agent_mode"]["enabled"], bool): - raise ValueError("enabled in agent_mode must be of boolean type") - - if not config["agent_mode"].get("strategy"): - config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value - - if config["agent_mode"]["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]: - raise ValueError("strategy in agent_mode must be in the specified strategy list") - - if not config["agent_mode"].get("tools"): - config["agent_mode"]["tools"] = [] - - if not isinstance(config["agent_mode"]["tools"], list): - raise ValueError("tools in agent_mode must be a list of objects") - - for tool in config["agent_mode"]["tools"]: - key = list(tool.keys())[0] - if key in OLD_TOOLS: - # old style, use tool name as key - tool_item = tool[key] - - if "enabled" not in tool_item or not tool_item["enabled"]: - tool_item["enabled"] = False - - if not isinstance(tool_item["enabled"], bool): - raise ValueError("enabled in agent_mode.tools must be of boolean type") - - if key == "dataset": - if 'id' not in tool_item: - raise ValueError("id is required in dataset") - - try: - uuid.UUID(tool_item["id"]) - except ValueError: - raise ValueError("id in dataset must be of UUID type") - - if not DatasetValidator.is_dataset_exists(tenant_id, tool_item["id"]): - raise ValueError("Dataset ID does not exist, please check your permission.") - else: - # latest style, use key-value pair - if "enabled" not in tool or not tool["enabled"]: - tool["enabled"] = False - if "provider_type" not in tool: - raise ValueError("provider_type is required in agent_mode.tools") - if "provider_id" not in tool: - raise ValueError("provider_id is required in agent_mode.tools") - if "tool_name" not in tool: - raise ValueError("tool_name is required in agent_mode.tools") - if "tool_parameters" not in tool: - raise ValueError("tool_parameters is required in agent_mode.tools") - - return config, ["agent_mode"] diff --git a/api/core/callback_handler/agent_loop_gather_callback_handler.py b/api/core/callback_handler/agent_loop_gather_callback_handler.py index 1d25b8ab69..8a340a8b81 100644 --- a/api/core/callback_handler/agent_loop_gather_callback_handler.py +++ b/api/core/callback_handler/agent_loop_gather_callback_handler.py @@ -7,7 +7,7 @@ from langchain.agents import openai_functions_agent, openai_functions_multi_agen from langchain.callbacks.base import BaseCallbackHandler from langchain.schema import AgentAction, AgentFinish, BaseMessage, LLMResult -from core.application_queue_manager import ApplicationQueueManager, PublishFrom +from 
core.app.app_queue_manager import AppQueueManager, PublishFrom from core.callback_handler.entity.agent_loop import AgentLoop from core.entities.application_entities import ModelConfigEntity from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult @@ -22,7 +22,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): raise_error: bool = True def __init__(self, model_config: ModelConfigEntity, - queue_manager: ApplicationQueueManager, + queue_manager: AppQueueManager, message: Message, message_chain: MessageChain) -> None: """Initialize callback handler.""" diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 879c9df69d..e49a09d4c4 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,5 +1,5 @@ -from core.application_queue_manager import ApplicationQueueManager, PublishFrom +from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.entities.application_entities import InvokeFrom from core.rag.models.document import Document from extensions.ext_database import db @@ -10,7 +10,7 @@ from models.model import DatasetRetrieverResource class DatasetIndexToolCallbackHandler: """Callback handler for dataset tool.""" - def __init__(self, queue_manager: ApplicationQueueManager, + def __init__(self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, diff --git a/api/core/features/external_data_fetch.py b/api/core/external_data_tool/external_data_fetch.py similarity index 98% rename from api/core/features/external_data_fetch.py rename to api/core/external_data_tool/external_data_fetch.py index ef37f05528..64c7d1e859 100644 --- a/api/core/features/external_data_fetch.py +++ b/api/core/external_data_tool/external_data_fetch.py @@ -11,7 +11,7 @@ from core.external_data_tool.factory import ExternalDataToolFactory logger = logging.getLogger(__name__) -class ExternalDataFetchFeature: +class ExternalDataFetch: def fetch(self, tenant_id: str, app_id: str, external_data_tools: list[ExternalDataVariableEntity], diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index dd46aa27dc..01a8ea3a5d 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -13,7 +13,7 @@ from sqlalchemy.orm.exc import ObjectDeletedError from core.docstore.dataset_docstore import DatasetDocumentStore from core.errors.error import ProviderTokenNotInitError -from core.generator.llm_generator import LLMGenerator +from core.llm_generator.llm_generator import LLMGenerator from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType, PriceType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel diff --git a/api/core/llm_generator/__init__.py b/api/core/llm_generator/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/generator/llm_generator.py b/api/core/llm_generator/llm_generator.py similarity index 93% rename from api/core/generator/llm_generator.py rename to api/core/llm_generator/llm_generator.py index 072b02dc94..6ce70df703 100644 --- a/api/core/generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -7,10 +7,10 @@ from core.model_manager import ModelManager from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import 
ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError -from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser -from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser -from core.prompt.prompt_template import PromptTemplateParser -from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT +from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser +from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.llm_generator.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT class LLMGenerator: diff --git a/api/core/llm_generator/output_parser/__init__.py b/api/core/llm_generator/output_parser/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/prompt/output_parser/rule_config_generator.py b/api/core/llm_generator/output_parser/rule_config_generator.py similarity index 94% rename from api/core/prompt/output_parser/rule_config_generator.py rename to api/core/llm_generator/output_parser/rule_config_generator.py index 619555ce2e..b95653f69c 100644 --- a/api/core/prompt/output_parser/rule_config_generator.py +++ b/api/core/llm_generator/output_parser/rule_config_generator.py @@ -2,7 +2,7 @@ from typing import Any from langchain.schema import BaseOutputParser, OutputParserException -from core.prompt.prompts import RULE_CONFIG_GENERATE_TEMPLATE +from core.llm_generator.prompts import RULE_CONFIG_GENERATE_TEMPLATE from libs.json_in_md_parser import parse_and_check_json_markdown diff --git a/api/core/prompt/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py similarity index 87% rename from api/core/prompt/output_parser/suggested_questions_after_answer.py rename to api/core/llm_generator/output_parser/suggested_questions_after_answer.py index e37142ec91..ad30bcfa07 100644 --- a/api/core/prompt/output_parser/suggested_questions_after_answer.py +++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py @@ -4,7 +4,7 @@ from typing import Any from langchain.schema import BaseOutputParser -from core.prompt.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT +from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser): diff --git a/api/core/prompt/prompts.py b/api/core/llm_generator/prompts.py similarity index 100% rename from api/core/prompt/prompts.py rename to api/core/llm_generator/prompts.py diff --git a/api/core/features/moderation.py b/api/core/moderation/input_moderation.py similarity index 98% rename from api/core/features/moderation.py rename to api/core/moderation/input_moderation.py index a9d65f56e8..2129c58d8d 100644 --- a/api/core/features/moderation.py +++ b/api/core/moderation/input_moderation.py @@ -7,7 +7,7 @@ from core.moderation.factory import ModerationFactory logger = logging.getLogger(__name__) -class ModerationFeature: +class InputModeration: def check(self, app_id: str, tenant_id: str, app_orchestration_config_entity: AppOrchestrationConfigEntity, diff --git a/api/core/app_runner/moderation_handler.py b/api/core/moderation/output_moderation.py similarity index 97% rename from 
api/core/app_runner/moderation_handler.py rename to api/core/moderation/output_moderation.py index b2098344c8..749ee431e8 100644 --- a/api/core/app_runner/moderation_handler.py +++ b/api/core/moderation/output_moderation.py @@ -6,7 +6,7 @@ from typing import Any, Optional from flask import Flask, current_app from pydantic import BaseModel -from core.application_queue_manager import PublishFrom +from core.app.app_queue_manager import PublishFrom from core.moderation.base import ModerationAction, ModerationOutputsResult from core.moderation.factory import ModerationFactory @@ -18,7 +18,7 @@ class ModerationRule(BaseModel): config: dict[str, Any] -class OutputModerationHandler(BaseModel): +class OutputModeration(BaseModel): DEFAULT_BUFFER_SIZE: int = 300 tenant_id: str diff --git a/api/core/prompt/__init__.py b/api/core/prompt/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 7519971ce7..6178453920 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -15,7 +15,7 @@ from core.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) -from core.prompt.prompt_template import PromptTemplateParser +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.prompt.prompt_transform import PromptTransform from core.prompt.simple_prompt_transform import ModelMode diff --git a/api/core/prompt/prompt_templates/__init__.py b/api/core/prompt/prompt_templates/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/prompt/advanced_prompt_templates.py b/api/core/prompt/prompt_templates/advanced_prompt_templates.py similarity index 100% rename from api/core/prompt/advanced_prompt_templates.py rename to api/core/prompt/prompt_templates/advanced_prompt_templates.py diff --git a/api/core/prompt/generate_prompts/baichuan_chat.json b/api/core/prompt/prompt_templates/baichuan_chat.json similarity index 100% rename from api/core/prompt/generate_prompts/baichuan_chat.json rename to api/core/prompt/prompt_templates/baichuan_chat.json diff --git a/api/core/prompt/generate_prompts/baichuan_completion.json b/api/core/prompt/prompt_templates/baichuan_completion.json similarity index 100% rename from api/core/prompt/generate_prompts/baichuan_completion.json rename to api/core/prompt/prompt_templates/baichuan_completion.json diff --git a/api/core/prompt/generate_prompts/common_chat.json b/api/core/prompt/prompt_templates/common_chat.json similarity index 100% rename from api/core/prompt/generate_prompts/common_chat.json rename to api/core/prompt/prompt_templates/common_chat.json diff --git a/api/core/prompt/generate_prompts/common_completion.json b/api/core/prompt/prompt_templates/common_completion.json similarity index 100% rename from api/core/prompt/generate_prompts/common_completion.json rename to api/core/prompt/prompt_templates/common_completion.json diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index fcae0dc786..f3a03b01c7 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -15,7 +15,7 @@ from core.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) -from core.prompt.prompt_template import PromptTemplateParser +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from 
core.prompt.prompt_transform import PromptTransform from models.model import AppMode @@ -275,7 +275,7 @@ class SimplePromptTransform(PromptTransform): return prompt_file_contents[prompt_file_name] # Get the absolute path of the subdirectory - prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'generate_prompts') + prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'prompt_templates') json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json') # Open the JSON file and read its content diff --git a/api/core/prompt/utils/__init__.py b/api/core/prompt/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/prompt/prompt_template.py b/api/core/prompt/utils/prompt_template_parser.py similarity index 100% rename from api/core/prompt/prompt_template.py rename to api/core/prompt/utils/prompt_template_parser.py diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 0d81c419d6..139bfe15f3 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -9,7 +9,7 @@ import pandas as pd from flask import Flask, current_app from werkzeug.datastructures import FileStorage -from core.generator.llm_generator import LLMGenerator +from core.llm_generator.llm_generator import LLMGenerator from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector diff --git a/api/core/rag/retrieval/__init__.py b/api/core/rag/retrieval/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/retrieval/agent/__init__.py b/api/core/rag/retrieval/agent/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/features/dataset_retrieval/agent/agent_llm_callback.py b/api/core/rag/retrieval/agent/agent_llm_callback.py similarity index 100% rename from api/core/features/dataset_retrieval/agent/agent_llm_callback.py rename to api/core/rag/retrieval/agent/agent_llm_callback.py diff --git a/api/core/features/dataset_retrieval/agent/fake_llm.py b/api/core/rag/retrieval/agent/fake_llm.py similarity index 100% rename from api/core/features/dataset_retrieval/agent/fake_llm.py rename to api/core/rag/retrieval/agent/fake_llm.py diff --git a/api/core/features/dataset_retrieval/agent/llm_chain.py b/api/core/rag/retrieval/agent/llm_chain.py similarity index 91% rename from api/core/features/dataset_retrieval/agent/llm_chain.py rename to api/core/rag/retrieval/agent/llm_chain.py index e5155e15a0..d07ee0a582 100644 --- a/api/core/features/dataset_retrieval/agent/llm_chain.py +++ b/api/core/rag/retrieval/agent/llm_chain.py @@ -7,8 +7,8 @@ from langchain.schema.language_model import BaseLanguageModel from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import lc_messages_to_prompt_messages -from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback -from core.features.dataset_retrieval.agent.fake_llm import FakeLLM +from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback +from core.rag.retrieval.agent.fake_llm import FakeLLM from core.model_manager import ModelInstance diff --git a/api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py b/api/core/rag/retrieval/agent/multi_dataset_router_agent.py similarity index 
98% rename from api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py rename to api/core/rag/retrieval/agent/multi_dataset_router_agent.py index 59923202fd..8cc2e29743 100644 --- a/api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py +++ b/api/core/rag/retrieval/agent/multi_dataset_router_agent.py @@ -12,7 +12,7 @@ from pydantic import root_validator from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import lc_messages_to_prompt_messages -from core.features.dataset_retrieval.agent.fake_llm import FakeLLM +from core.rag.retrieval.agent.fake_llm import FakeLLM from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import PromptMessageTool diff --git a/api/core/rag/retrieval/agent/output_parser/__init__.py b/api/core/rag/retrieval/agent/output_parser/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/features/dataset_retrieval/agent/output_parser/structured_chat.py b/api/core/rag/retrieval/agent/output_parser/structured_chat.py similarity index 100% rename from api/core/features/dataset_retrieval/agent/output_parser/structured_chat.py rename to api/core/rag/retrieval/agent/output_parser/structured_chat.py diff --git a/api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py b/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py similarity index 99% rename from api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py rename to api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py index e69302bfd6..4d7d33038b 100644 --- a/api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py +++ b/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py @@ -13,7 +13,7 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException from langchain.tools import BaseTool from core.entities.application_entities import ModelConfigEntity -from core.features.dataset_retrieval.agent.llm_chain import LLMChain +from core.rag.retrieval.agent.llm_chain import LLMChain FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. 
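The FORMAT_INSTRUCTIONS context above asks the router model to answer with a JSON blob carrying an `action` key (the tool name) and an `action_input` key (the tool input), which the relocated structured-chat output parser then turns into an agent step. A minimal sketch of that parsing idea, assuming the blob arrives in a fenced json block — `parse_action_blob` is a hypothetical helper for illustration, not the actual `StructuredChatOutputParser` implementation:

```python
import json
import re

# Captures the {...} blob between a pair of (optionally json-tagged) fences.
ACTION_BLOB = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL)


def parse_action_blob(text: str) -> tuple[str, dict]:
    # Prefer a fenced blob as requested by FORMAT_INSTRUCTIONS; fall back to
    # treating the whole reply as JSON. The real parser also handles final
    # answers and malformed output (raising OutputParserException).
    match = ACTION_BLOB.search(text)
    payload = match.group(1) if match else text.strip()
    blob = json.loads(payload)
    return blob["action"], blob.get("action_input", {})


reply = 'Action:\n```json\n{"action": "dataset-123", "action_input": {"query": "pricing"}}\n```'
tool_name, tool_input = parse_action_blob(reply)  # ("dataset-123", {"query": "pricing"})
```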
diff --git a/api/core/features/dataset_retrieval/agent_based_dataset_executor.py b/api/core/rag/retrieval/agent_based_dataset_executor.py similarity index 92% rename from api/core/features/dataset_retrieval/agent_based_dataset_executor.py rename to api/core/rag/retrieval/agent_based_dataset_executor.py index 588ccc91f5..f1ccf986e9 100644 --- a/api/core/features/dataset_retrieval/agent_based_dataset_executor.py +++ b/api/core/rag/retrieval/agent_based_dataset_executor.py @@ -10,10 +10,10 @@ from pydantic import BaseModel, Extra from core.entities.agent_entities import PlanningStrategy from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import prompt_messages_to_lc_messages -from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback -from core.features.dataset_retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent -from core.features.dataset_retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser -from core.features.dataset_retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent +from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback +from core.rag.retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent +from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser +from core.rag.retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent from core.helper import moderation from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.errors.invoke import InvokeError diff --git a/api/core/features/dataset_retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py similarity index 98% rename from api/core/features/dataset_retrieval/dataset_retrieval.py rename to api/core/rag/retrieval/dataset_retrieval.py index 3e54d8644d..07682389d6 100644 --- a/api/core/features/dataset_retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -5,7 +5,7 @@ from langchain.tools import BaseTool from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.agent_entities import PlanningStrategy from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity -from core.features.dataset_retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor +from core.rag.retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -15,7 +15,7 @@ from extensions.ext_database import db from models.dataset import Dataset -class DatasetRetrievalFeature: +class DatasetRetrieval: def retrieve(self, tenant_id: str, model_config: ModelConfigEntity, config: DatasetEntity, diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py index 30128c4dca..629ed23613 100644 --- a/api/core/tools/tool/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever_tool.py @@ -4,7 +4,7 @@ from langchain.tools import BaseTool from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom 
-from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter from core.tools.tool.tool import Tool @@ -30,7 +30,7 @@ class DatasetRetrieverTool(Tool): if retrieve_config is None: return [] - feature = DatasetRetrievalFeature() + feature = DatasetRetrieval() # save original retrieve strategy, and set retrieve strategy to SINGLE # Agent only support SINGLE mode diff --git a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py index 74dc8d5112..f5f3ba2540 100644 --- a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py +++ b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py @@ -1,4 +1,4 @@ -from core.generator.llm_generator import LLMGenerator +from core.llm_generator.llm_generator import LLMGenerator from events.message_event import message_was_created from extensions.ext_database import db diff --git a/api/models/model.py b/api/models/model.py index 8d286d3482..235f77abc3 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -310,22 +310,28 @@ class AppModelConfig(db.Model): def from_model_config_dict(self, model_config: dict): self.opening_statement = model_config['opening_statement'] - self.suggested_questions = json.dumps(model_config['suggested_questions']) - self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) + self.suggested_questions = json.dumps(model_config['suggested_questions']) \ + if model_config.get('suggested_questions') else None + self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) \ + if model_config.get('suggested_questions_after_answer') else None self.speech_to_text = json.dumps(model_config['speech_to_text']) \ if model_config.get('speech_to_text') else None self.text_to_speech = json.dumps(model_config['text_to_speech']) \ if model_config.get('text_to_speech') else None - self.more_like_this = json.dumps(model_config['more_like_this']) + self.more_like_this = json.dumps(model_config['more_like_this']) \ + if model_config.get('more_like_this') else None self.sensitive_word_avoidance = json.dumps(model_config['sensitive_word_avoidance']) \ if model_config.get('sensitive_word_avoidance') else None self.external_data_tools = json.dumps(model_config['external_data_tools']) \ if model_config.get('external_data_tools') else None - self.model = json.dumps(model_config['model']) - self.user_input_form = json.dumps(model_config['user_input_form']) + self.model = json.dumps(model_config['model']) \ + if model_config.get('model') else None + self.user_input_form = json.dumps(model_config['user_input_form']) \ + if model_config.get('user_input_form') else None self.dataset_query_variable = model_config.get('dataset_query_variable') self.pre_prompt = model_config['pre_prompt'] - self.agent_mode = json.dumps(model_config['agent_mode']) + self.agent_mode = json.dumps(model_config['agent_mode']) \ + if model_config.get('agent_mode') else None self.retriever_resource = json.dumps(model_config['retriever_resource']) \ if model_config.get('retriever_resource') else None self.prompt_type = model_config.get('prompt_type', 'simple') diff --git 
a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index 1e893e0eca..213df26222 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -1,7 +1,7 @@ import copy -from core.prompt.advanced_prompt_templates import ( +from core.prompt.prompt_templates.advanced_prompt_templates import ( BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index c1e0ecebe8..789d74ed2c 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -1,8 +1,8 @@ -from core.apps.app_config_validators.advanced_chat_app import AdvancedChatAppConfigValidator -from core.apps.app_config_validators.agent_chat_app import AgentChatAppConfigValidator -from core.apps.app_config_validators.chat_app import ChatAppConfigValidator -from core.apps.app_config_validators.completion_app import CompletionAppConfigValidator -from core.apps.app_config_validators.workflow_app import WorkflowAppConfigValidator +from core.app.advanced_chat.config_validator import AdvancedChatAppConfigValidator +from core.app.agent_chat.config_validator import AgentChatAppConfigValidator +from core.app.chat.config_validator import ChatAppConfigValidator +from core.app.completion.config_validator import CompletionAppConfigValidator +from core.app.workflow.config_validator import WorkflowAppConfigValidator from models.model import AppMode diff --git a/api/services/completion_service.py b/api/services/completion_service.py index 9acd62b997..8a9639e521 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -4,8 +4,8 @@ from typing import Any, Union from sqlalchemy import and_ -from core.application_manager import ApplicationManager -from core.apps.config_validators.model import ModelValidator +from core.app.app_manager import AppManager +from core.app.validators.model_validator import ModelValidator from core.entities.application_entities import InvokeFrom from core.file.message_file_parser import MessageFileParser from extensions.ext_database import db @@ -137,7 +137,7 @@ class CompletionService: user ) - application_manager = ApplicationManager() + application_manager = AppManager() return application_manager.generate( tenant_id=app_model.tenant_id, app_id=app_model.id, @@ -193,7 +193,7 @@ class CompletionService: message.files, app_model_config ) - application_manager = ApplicationManager() + application_manager = AppManager() return application_manager.generate( tenant_id=app_model.tenant_id, app_id=app_model.id, diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index ac3df380b2..1a0213799e 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,6 +1,6 @@ from typing import Optional, Union -from core.generator.llm_generator import LLMGenerator +from core.llm_generator.llm_generator import LLMGenerator from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.account import Account diff --git a/api/services/message_service.py b/api/services/message_service.py index ad2ff60f6b..20918a8781 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -1,7 +1,7 @@ import json from typing import Optional, Union -from 
core.generator.llm_generator import LLMGenerator +from core.llm_generator.llm_generator import LLMGenerator from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index fb6cf1fd5a..f384855e7a 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,7 +1,7 @@ import json from typing import Optional -from core.application_manager import ApplicationManager +from core.app.app_manager import AppManager from core.entities.application_entities import ( DatasetEntity, DatasetRetrieveConfigEntity, @@ -111,7 +111,7 @@ class WorkflowConverter: new_app_mode = self._get_new_app_mode(app_model) # convert app model config - application_manager = ApplicationManager() + application_manager = AppManager() app_orchestration_config_entity = application_manager.convert_from_app_model_config_dict( tenant_id=app_model.tenant_id, app_model_config_dict=app_model_config.to_dict(), diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 95f1e30b44..69acb23681 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -8,7 +8,7 @@ from core.file.file_obj import FileObj, FileType, FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.prompt_template import PromptTemplateParser +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import Conversation From 9467fe9aa9f14a111816abc739fdecfd7c043d84 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 29 Feb 2024 17:34:18 +0800 Subject: [PATCH 085/200] lint fix --- api/core/agent/base_agent_runner.py | 2 +- api/core/agent/cot_agent_runner.py | 2 +- api/core/agent/fc_agent_runner.py | 2 +- api/core/app/agent_chat/app_runner.py | 6 ++--- api/core/app/agent_chat/config_validator.py | 3 +-- api/core/app/app_manager.py | 4 ++-- .../app/app_orchestration_config_converter.py | 23 +++++++++++++++---- api/core/app/base_app_runner.py | 6 ++--- api/core/app/chat/app_runner.py | 4 ++-- api/core/app/completion/app_runner.py | 4 ++-- api/core/app/generate_task_pipeline.py | 2 +- api/core/llm_generator/llm_generator.py | 6 ++--- .../suggested_questions_after_answer.py | 1 + api/core/prompt/advanced_prompt_transform.py | 2 +- api/core/prompt/simple_prompt_transform.py | 2 +- api/core/rag/retrieval/agent/llm_chain.py | 2 +- .../agent/multi_dataset_router_agent.py | 2 +- .../retrieval/agent_based_dataset_executor.py | 6 ++--- api/core/rag/retrieval/dataset_retrieval.py | 2 +- 19 files changed, 47 insertions(+), 34 deletions(-) diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 0658124d14..1474c6a475 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -5,8 +5,8 @@ from datetime import datetime from mimetypes import guess_extension from typing import Optional, Union, cast -from core.app.base_app_runner import AppRunner from core.app.app_queue_manager import AppQueueManager +from 
core.app.base_app_runner import AppRunner
 from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.application_entities import (
diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py
index 152e445795..5650113f47 100644
--- a/api/core/agent/cot_agent_runner.py
+++ b/api/core/agent/cot_agent_runner.py
@@ -3,9 +3,9 @@ import re
 from collections.abc import Generator
 from typing import Literal, Union
 
+from core.agent.base_agent_runner import BaseAgentRunner
 from core.app.app_queue_manager import PublishFrom
 from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
-from core.agent.base_agent_runner import BaseAgentRunner
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index 0cf0d3762c..9b238bf232 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -3,8 +3,8 @@ import logging
 from collections.abc import Generator
 from typing import Any, Union
 
-from core.app.app_queue_manager import PublishFrom
 from core.agent.base_agent_runner import BaseAgentRunner
+from core.app.app_queue_manager import PublishFrom
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
diff --git a/api/core/app/agent_chat/app_runner.py b/api/core/app/agent_chat/app_runner.py
index b046e935a5..38789348ad 100644
--- a/api/core/app/agent_chat/app_runner.py
+++ b/api/core/app/agent_chat/app_runner.py
@@ -1,11 +1,11 @@
 import logging
 from typing import cast
 
-from core.app.base_app_runner import AppRunner
-from core.app.app_queue_manager import AppQueueManager, PublishFrom
-from core.entities.application_entities import AgentEntity, ApplicationGenerateEntity, ModelConfigEntity
 from core.agent.cot_agent_runner import CotAgentRunner
 from core.agent.fc_agent_runner import FunctionCallAgentRunner
+from core.app.app_queue_manager import AppQueueManager, PublishFrom
+from core.app.base_app_runner import AppRunner
+from core.entities.application_entities import AgentEntity, ApplicationGenerateEntity, ModelConfigEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMUsage
diff --git a/api/core/app/agent_chat/config_validator.py b/api/core/app/agent_chat/config_validator.py
index 6596b19f99..82bc40bd9b 100644
--- a/api/core/app/agent_chat/config_validator.py
+++ b/api/core/app/agent_chat/config_validator.py
@@ -1,6 +1,5 @@
 import uuid
 
-from core.entities.agent_entities import PlanningStrategy
 from core.app.validators.dataset_retrieval import DatasetValidator
 from core.app.validators.external_data_fetch import ExternalDataFetchValidator
 from core.app.validators.file_upload import FileUploadValidator
@@ -13,9 +12,9 @@ from core.app.validators.speech_to_text import SpeechToTextValidator
 from core.app.validators.suggested_questions import SuggestedQuestionsValidator
 from core.app.validators.text_to_speech import TextToSpeechValidator
 from core.app.validators.user_input_form import UserInputFormValidator
+from core.entities.agent_entities import PlanningStrategy
 from models.model import AppMode
-
 
 OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"]
diff --git a/api/core/app/app_manager.py b/api/core/app/app_manager.py
index 0819ed864b..86c8d2cfc7 100644
--- a/api/core/app/app_manager.py
+++ b/api/core/app/app_manager.py
@@ -8,11 +8,11 @@ from typing import Any, Optional, Union, cast
 from flask import Flask, current_app
 from pydantic import ValidationError
 
-from core.app.app_orchestration_config_converter import AppOrchestrationConfigConverter
 from core.app.agent_chat.app_runner import AgentChatAppRunner
+from core.app.app_orchestration_config_converter import AppOrchestrationConfigConverter
+from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
 from core.app.chat.app_runner import ChatAppRunner
 from core.app.generate_task_pipeline import GenerateTaskPipeline
-from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
 from core.entities.application_entities import (
     ApplicationGenerateEntity,
     InvokeFrom,
diff --git a/api/core/app/app_orchestration_config_converter.py b/api/core/app/app_orchestration_config_converter.py
index ddf49949a3..1d429ee6d9 100644
--- a/api/core/app/app_orchestration_config_converter.py
+++ b/api/core/app/app_orchestration_config_converter.py
@@ -1,11 +1,24 @@
 from typing import cast
 
-from core.entities.application_entities import AppOrchestrationConfigEntity, SensitiveWordAvoidanceEntity, \
-    TextToSpeechEntity, DatasetRetrieveConfigEntity, DatasetEntity, AgentPromptEntity, AgentEntity, AgentToolEntity, \
-    ExternalDataVariableEntity, VariableEntity, AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity, \
-    AdvancedChatPromptTemplateEntity, ModelConfigEntity, FileUploadEntity
+from core.entities.application_entities import (
+    AdvancedChatPromptTemplateEntity,
+    AdvancedCompletionPromptTemplateEntity,
+    AgentEntity,
+    AgentPromptEntity,
+    AgentToolEntity,
+    AppOrchestrationConfigEntity,
+    DatasetEntity,
+    DatasetRetrieveConfigEntity,
+    ExternalDataVariableEntity,
+    FileUploadEntity,
+    ModelConfigEntity,
+    PromptTemplateEntity,
+    SensitiveWordAvoidanceEntity,
+    TextToSpeechEntity,
+    VariableEntity,
+)
 from core.entities.model_entities import ModelStatus
-from core.errors.error import ProviderTokenNotInitError, ModelCurrentlyNotSupportError, QuotaExceededError
+from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.entities.message_entities import PromptMessageRole
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
diff --git a/api/core/app/base_app_runner.py b/api/core/app/base_app_runner.py
index 788e3f91a3..2760d04180 100644
--- a/api/core/app/base_app_runner.py
+++ b/api/core/app/base_app_runner.py
@@ -3,6 +3,8 @@ from collections.abc import Generator
 from typing import Optional, Union, cast
 
 from core.app.app_queue_manager import AppQueueManager, PublishFrom
+from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
+from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
 from core.entities.application_entities import (
     ApplicationGenerateEntity,
     AppOrchestrationConfigEntity,
@@ -11,10 +13,7 @@ from core.entities.application_entities import (
     ModelConfigEntity,
     PromptTemplateEntity,
 )
-from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
 from core.external_data_tool.external_data_fetch import ExternalDataFetch
-from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
-from core.moderation.input_moderation import InputModeration
 from core.file.file_obj import FileObj
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -22,6 +21,7 @@ from core.model_runtime.entities.message_entities import AssistantPromptMessage,
 from core.model_runtime.entities.model_entities import ModelPropertyKey
 from core.model_runtime.errors.invoke import InvokeBadRequestError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.moderation.input_moderation import InputModeration
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.simple_prompt_transform import SimplePromptTransform
 from models.model import App, AppMode, Message, MessageAnnotation
diff --git a/api/core/app/chat/app_runner.py b/api/core/app/chat/app_runner.py
index a1613e37a2..a1eccab13a 100644
--- a/api/core/app/chat/app_runner.py
+++ b/api/core/app/chat/app_runner.py
@@ -1,8 +1,8 @@
 import logging
 from typing import Optional
 
-from core.app.base_app_runner import AppRunner
 from core.app.app_queue_manager import AppQueueManager, PublishFrom
+from core.app.base_app_runner import AppRunner
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.application_entities import (
     ApplicationGenerateEntity,
@@ -10,10 +10,10 @@ from core.entities.application_entities import (
     InvokeFrom,
     ModelConfigEntity,
 )
-from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.moderation.base import ModerationException
+from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from extensions.ext_database import db
 from models.model import App, AppMode, Conversation, Message
diff --git a/api/core/app/completion/app_runner.py b/api/core/app/completion/app_runner.py
index 34c6a5156f..3ac182b34e 100644
--- a/api/core/app/completion/app_runner.py
+++ b/api/core/app/completion/app_runner.py
@@ -1,8 +1,8 @@
 import logging
 from typing import Optional
 
-from core.app.base_app_runner import AppRunner
 from core.app.app_queue_manager import AppQueueManager, PublishFrom
+from core.app.base_app_runner import AppRunner
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.application_entities import (
     ApplicationGenerateEntity,
@@ -10,10 +10,10 @@ from core.entities.application_entities import (
     InvokeFrom,
     ModelConfigEntity,
 )
-from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.moderation.base import ModerationException
+from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from extensions.ext_database import db
 from models.model import App, AppMode, Conversation, Message
diff --git a/api/core/app/generate_task_pipeline.py b/api/core/app/generate_task_pipeline.py
index 6d52fa7348..dc6ea2db79 100644
--- a/api/core/app/generate_task_pipeline.py
+++ b/api/core/app/generate_task_pipeline.py
@@ -6,7 +6,6 @@ from typing import Optional, Union, cast
 
 from pydantic import BaseModel
 
-from core.moderation.output_moderation import ModerationRule, OutputModeration
 from core.app.app_queue_manager import AppQueueManager, PublishFrom
 from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom
 from core.entities.queue_entities import (
@@ -35,6 +34,7 @@ from core.model_runtime.entities.message_entities import (
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
+from core.moderation.output_moderation import ModerationRule, OutputModeration
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.tools.tool_file_manager import ToolFileManager
 from events.message_event import message_was_created
diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py
index 6ce70df703..1a6b71fb0a 100644
--- a/api/core/llm_generator/llm_generator.py
+++ b/api/core/llm_generator/llm_generator.py
@@ -3,14 +3,14 @@ import logging
 
 from langchain.schema import OutputParserException
 
+from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
+from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
+from core.llm_generator.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT
 from core.model_manager import ModelManager
 from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
-from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
-from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
-from core.llm_generator.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT
 
 
 class LLMGenerator:
diff --git a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py
index ad30bcfa07..1b955c6edd 100644
--- a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py
+++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py
@@ -5,6 +5,7 @@ from typing import Any
 from langchain.schema import BaseOutputParser
 
 from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
+from core.model_runtime.errors.invoke import InvokeError
 
 
 class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser):
diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py
index 6178453920..6d0a1d31f5 100644
--- a/api/core/prompt/advanced_prompt_transform.py
+++ b/api/core/prompt/advanced_prompt_transform.py
@@ -15,9 +15,9 @@ from core.model_runtime.entities.message_entities import (
     TextPromptMessageContent,
     UserPromptMessage,
 )
-from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.prompt.prompt_transform import PromptTransform
 from core.prompt.simple_prompt_transform import ModelMode
+from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 
 
 class AdvancedPromptTransform(PromptTransform):
diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py
index f3a03b01c7..af7b695bb3 100644
--- a/api/core/prompt/simple_prompt_transform.py
+++ b/api/core/prompt/simple_prompt_transform.py
@@ -15,8 +15,8 @@ from core.model_runtime.entities.message_entities import (
     TextPromptMessageContent,
     UserPromptMessage,
 )
-from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.prompt.prompt_transform import PromptTransform
+from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from models.model import AppMode
diff --git a/api/core/rag/retrieval/agent/llm_chain.py b/api/core/rag/retrieval/agent/llm_chain.py
index d07ee0a582..087b7bfa2c 100644
--- a/api/core/rag/retrieval/agent/llm_chain.py
+++ b/api/core/rag/retrieval/agent/llm_chain.py
@@ -7,9 +7,9 @@ from langchain.schema.language_model import BaseLanguageModel
 
 from core.entities.application_entities import ModelConfigEntity
 from core.entities.message_entities import lc_messages_to_prompt_messages
+from core.model_manager import ModelInstance
 from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback
 from core.rag.retrieval.agent.fake_llm import FakeLLM
-from core.model_manager import ModelInstance
 
 
 class LLMChain(LCLLMChain):
diff --git a/api/core/rag/retrieval/agent/multi_dataset_router_agent.py b/api/core/rag/retrieval/agent/multi_dataset_router_agent.py
index 8cc2e29743..41a0c54041 100644
--- a/api/core/rag/retrieval/agent/multi_dataset_router_agent.py
+++ b/api/core/rag/retrieval/agent/multi_dataset_router_agent.py
@@ -12,9 +12,9 @@ from pydantic import root_validator
 
 from core.entities.application_entities import ModelConfigEntity
 from core.entities.message_entities import lc_messages_to_prompt_messages
-from core.rag.retrieval.agent.fake_llm import FakeLLM
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.message_entities import PromptMessageTool
+from core.rag.retrieval.agent.fake_llm import FakeLLM
 
 
 class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
diff --git a/api/core/rag/retrieval/agent_based_dataset_executor.py b/api/core/rag/retrieval/agent_based_dataset_executor.py
index f1ccf986e9..7fabf71bed 100644
--- a/api/core/rag/retrieval/agent_based_dataset_executor.py
+++ b/api/core/rag/retrieval/agent_based_dataset_executor.py
@@ -10,13 +10,13 @@ from pydantic import BaseModel, Extra
 from core.entities.agent_entities import PlanningStrategy
 from core.entities.application_entities import ModelConfigEntity
 from core.entities.message_entities import prompt_messages_to_lc_messages
+from core.helper import moderation
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.errors.invoke import InvokeError
 from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback
 from core.rag.retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
 from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
 from core.rag.retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
-from core.helper import moderation
-from core.memory.token_buffer_memory import TokenBufferMemory
-from core.model_runtime.errors.invoke import InvokeError
 from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
 from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 07682389d6..21e16c4162 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -5,10 +5,10 @@ from langchain.tools import BaseTool
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.agent_entities import PlanningStrategy
 from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity
-from core.rag.retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.rag.retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
 from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
 from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
 from extensions.ext_database import db

From 8a8882ed8d09882de1c02a55c4b35bdf0eee9dcd Mon Sep 17 00:00:00 2001
From: takatost
Date: Thu, 29 Feb 2024 22:03:03 +0800
Subject: [PATCH 086/200] move workflow_id to app

---
 api/constants/model_template.py               |  11 +-
 api/controllers/console/app/workflow.py       |   8 +-
 api/core/app/chat/app_runner.py               |  81 ++---------
 api/core/app/completion/app_runner.py         | 134 +++---------------
 api/fields/workflow_fields.py                 |   5 +-
 .../versions/b289e2408ee2_add_workflow.py     |   5 +-
 api/models/model.py                           |  22 ++-
 api/models/workflow.py                        |  10 ++
 api/services/app_service.py                   | 104 +++++++++-----
 api/services/workflow/workflow_converter.py   |  54 ++++---
 api/services/workflow_service.py              |  39 ++---
 11 files changed, 170 insertions(+), 303 deletions(-)

diff --git a/api/constants/model_template.py b/api/constants/model_template.py
index 61aab64d8a..c8aaba23cb 100644
--- a/api/constants/model_template.py
+++ b/api/constants/model_template.py
@@ -7,8 +7,7 @@ default_app_templates = {
             'mode': AppMode.WORKFLOW.value,
             'enable_site': True,
             'enable_api': True
-        },
-        'model_config': {}
+        }
     },
 
     # chat default mode
@@ -34,14 +33,6 @@ default_app_templates = {
             'mode': AppMode.ADVANCED_CHAT.value,
             'enable_site': True,
             'enable_api': True
-        },
-        'model_config': {
-            'model': {
-                "provider": "openai",
-                "name": "gpt-4",
-                "mode": "chat",
-                "completion_params": {}
-            }
+        }
         }
     },
diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index 4fcf8daf6e..54585d8519 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -41,10 +41,16 @@ class DraftWorkflowApi(Resource):
         """
         parser = reqparse.RequestParser()
         parser.add_argument('graph', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument('features', type=dict, required=True, nullable=False, location='json')
         args = parser.parse_args()
 
         workflow_service = WorkflowService()
-        workflow_service.sync_draft_workflow(app_model=app_model, graph=args.get('graph'), account=current_user)
+        workflow_service.sync_draft_workflow(
+            app_model=app_model,
+            graph=args.get('graph'),
+            features=args.get('features'),
+            account=current_user
+        )
 
         return {
             "result": "success"
diff --git a/api/core/app/chat/app_runner.py b/api/core/app/chat/app_runner.py
index a1eccab13a..4c8018572e 100644
--- a/api/core/app/chat/app_runner.py
+++ b/api/core/app/chat/app_runner.py
@@ -1,21 +1,17 @@
 import logging
-from typing import Optional
 
 from core.app.app_queue_manager import AppQueueManager, PublishFrom
 from core.app.base_app_runner import AppRunner
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.application_entities import (
     ApplicationGenerateEntity,
-    DatasetEntity,
-    InvokeFrom,
-    ModelConfigEntity,
 )
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.moderation.base import ModerationException
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from extensions.ext_database import db
-from models.model import App, AppMode, Conversation, Message
+from models.model import App, Conversation, Message
 
 
 logger = logging.getLogger(__name__)
@@ -145,18 +141,23 @@ class ChatAppRunner(AppRunner):
         # get context from datasets
         context = None
         if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids:
-            context = self.retrieve_dataset_context(
+            hit_callback = DatasetIndexToolCallbackHandler(
+                queue_manager,
+                app_record.id,
+                message.id,
+                application_generate_entity.user_id,
+                application_generate_entity.invoke_from
+            )
+
+            dataset_retrieval = DatasetRetrieval()
+            context = dataset_retrieval.retrieve(
                 tenant_id=app_record.tenant_id,
-                app_record=app_record,
-                queue_manager=queue_manager,
                 model_config=app_orchestration_config.model_config,
-                show_retrieve_source=app_orchestration_config.show_retrieve_source,
-                dataset_config=app_orchestration_config.dataset,
-                message=message,
-                inputs=inputs,
+                config=app_orchestration_config.dataset,
                 query=query,
-                user_id=application_generate_entity.user_id,
                 invoke_from=application_generate_entity.invoke_from,
+                show_retrieve_source=app_orchestration_config.show_retrieve_source,
+                hit_callback=hit_callback,
                 memory=memory
             )
@@ -212,57 +213,3 @@ class ChatAppRunner(AppRunner):
             queue_manager=queue_manager,
             stream=application_generate_entity.stream
         )
-
-    def retrieve_dataset_context(self, tenant_id: str,
-                                 app_record: App,
-                                 queue_manager: AppQueueManager,
-                                 model_config: ModelConfigEntity,
-                                 dataset_config: DatasetEntity,
-                                 show_retrieve_source: bool,
-                                 message: Message,
-                                 inputs: dict,
-                                 query: str,
-                                 user_id: str,
-                                 invoke_from: InvokeFrom,
-                                 memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
-        """
-        Retrieve dataset context
-        :param tenant_id: tenant id
-        :param app_record: app record
-        :param queue_manager: queue manager
-        :param model_config: model config
-        :param dataset_config: dataset config
-        :param show_retrieve_source: show retrieve source
-        :param message: message
-        :param inputs: inputs
-        :param query: query
-        :param user_id: user id
-        :param invoke_from: invoke from
-        :param memory: memory
-        :return:
-        """
-        hit_callback = DatasetIndexToolCallbackHandler(
-            queue_manager,
-            app_record.id,
-            message.id,
-            user_id,
-            invoke_from
-        )
-
-        # TODO
-        if (app_record.mode == AppMode.COMPLETION.value and dataset_config
-                and dataset_config.retrieve_config.query_variable):
-            query = inputs.get(dataset_config.retrieve_config.query_variable, "")
-
-        dataset_retrieval = DatasetRetrieval()
-        return dataset_retrieval.retrieve(
-            tenant_id=tenant_id,
-            model_config=model_config,
-            config=dataset_config,
-            query=query,
-            invoke_from=invoke_from,
-            show_retrieve_source=show_retrieve_source,
-            hit_callback=hit_callback,
-            memory=memory
-        )
-        
\ No newline at end of file
diff --git a/api/core/app/completion/app_runner.py b/api/core/app/completion/app_runner.py
index 3ac182b34e..ab2f40ad9a 100644
--- a/api/core/app/completion/app_runner.py
+++ b/api/core/app/completion/app_runner.py
@@ -1,21 +1,16 @@
 import logging
-from typing import Optional
 
-from core.app.app_queue_manager import AppQueueManager, PublishFrom
+from core.app.app_queue_manager import AppQueueManager
 from core.app.base_app_runner import AppRunner
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.application_entities import (
     ApplicationGenerateEntity,
-    DatasetEntity,
-    InvokeFrom,
-    ModelConfigEntity,
 )
-from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.moderation.base import ModerationException
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from extensions.ext_database import db
-from models.model import App, AppMode, Conversation, Message
+from models.model import App, Message
 
 
 logger = logging.getLogger(__name__)
@@ -27,13 +22,11 @@ class CompletionAppRunner(AppRunner):
 
     def run(self, application_generate_entity: ApplicationGenerateEntity,
             queue_manager: AppQueueManager,
-            conversation: Conversation,
             message: Message) -> None:
         """
         Run application
         :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
-        :param conversation: conversation
        :param message: message
        :return:
        """
@@ -61,30 +54,15 @@ class CompletionAppRunner(AppRunner):
             query=query
         )
 
-        memory = None
-        if application_generate_entity.conversation_id:
-            # get memory of conversation (read-only)
-            model_instance = ModelInstance(
-                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
-                model=app_orchestration_config.model_config.model
-            )
-
-            memory = TokenBufferMemory(
-                conversation=conversation,
-                model_instance=model_instance
-            )
-
         # organize all inputs and template to prompt messages
         # Include: prompt template, inputs, query(optional), files(optional)
-        # memory(optional)
         prompt_messages, stop = self.organize_prompt_messages(
             app_record=app_record,
             model_config=app_orchestration_config.model_config,
             prompt_template_entity=app_orchestration_config.prompt_template,
             inputs=inputs,
             files=files,
-            query=query,
-            memory=memory
+            query=query
         )
 
         # moderation
@@ -107,30 +85,6 @@ class CompletionAppRunner(AppRunner):
             )
             return
 
-        if query:
-            # annotation reply
-            annotation_reply = self.query_app_annotations_to_reply(
-                app_record=app_record,
-                message=message,
-                query=query,
-                user_id=application_generate_entity.user_id,
-                invoke_from=application_generate_entity.invoke_from
-            )
-
-            if annotation_reply:
-                queue_manager.publish_annotation_reply(
-                    message_annotation_id=annotation_reply.id,
-                    pub_from=PublishFrom.APPLICATION_MANAGER
-                )
-
-                self.direct_output(
-                    queue_manager=queue_manager,
-                    app_orchestration_config=app_orchestration_config,
-                    prompt_messages=prompt_messages,
-                    text=annotation_reply.content,
-                    stream=application_generate_entity.stream
-                )
-                return
-
         # fill in variable inputs from external data tools if exists
         external_data_tools = app_orchestration_config.external_data_variables
         if external_data_tools:
@@ -145,19 +99,27 @@ class CompletionAppRunner(AppRunner):
         # get context from datasets
         context = None
         if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids:
-            context = self.retrieve_dataset_context(
+            hit_callback = DatasetIndexToolCallbackHandler(
+                queue_manager,
+                app_record.id,
+                message.id,
+                application_generate_entity.user_id,
+                application_generate_entity.invoke_from
+            )
+
+            dataset_config = app_orchestration_config.dataset
+            if dataset_config and dataset_config.retrieve_config.query_variable:
+                query = inputs.get(dataset_config.retrieve_config.query_variable, "")
+
+            dataset_retrieval = DatasetRetrieval()
+            context = dataset_retrieval.retrieve(
                 tenant_id=app_record.tenant_id,
-                app_record=app_record,
-                queue_manager=queue_manager,
                 model_config=app_orchestration_config.model_config,
-                show_retrieve_source=app_orchestration_config.show_retrieve_source,
-                dataset_config=app_orchestration_config.dataset,
-                message=message,
-                inputs=inputs,
+                config=dataset_config,
                 query=query,
-                user_id=application_generate_entity.user_id,
                 invoke_from=application_generate_entity.invoke_from,
-                memory=memory
+                show_retrieve_source=app_orchestration_config.show_retrieve_source,
+                hit_callback=hit_callback
             )
 
         # reorganize all inputs and template to prompt messages
@@ -170,8 +132,7 @@ class CompletionAppRunner(AppRunner):
             inputs=inputs,
             files=files,
             query=query,
-            context=context,
-            memory=memory
+            context=context
         )
 
         # check hosting moderation
@@ -210,57 +171,4 @@ class CompletionAppRunner(AppRunner):
             queue_manager=queue_manager,
             stream=application_generate_entity.stream
         )
-
-    def retrieve_dataset_context(self, tenant_id: str,
-                                 app_record: App,
-                                 queue_manager: AppQueueManager,
-                                 model_config: ModelConfigEntity,
-                                 dataset_config: DatasetEntity,
-                                 show_retrieve_source: bool,
-                                 message: Message,
-                                 inputs: dict,
-                                 query: str,
-                                 user_id: str,
-                                 invoke_from: InvokeFrom,
-                                 memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
-        """
-        Retrieve dataset context
-        :param tenant_id: tenant id
-        :param app_record: app record
-        :param queue_manager: queue manager
-        :param model_config: model config
-        :param dataset_config: dataset config
-        :param show_retrieve_source: show retrieve source
-        :param message: message
-        :param inputs: inputs
-        :param query: query
-        :param user_id: user id
-        :param invoke_from: invoke from
-        :param memory: memory
-        :return:
-        """
-        hit_callback = DatasetIndexToolCallbackHandler(
-            queue_manager,
-            app_record.id,
-            message.id,
-            user_id,
-            invoke_from
-        )
-
-        # TODO
-        if (app_record.mode == AppMode.COMPLETION.value and dataset_config
-                and dataset_config.retrieve_config.query_variable):
-            query = inputs.get(dataset_config.retrieve_config.query_variable, "")
-
-        dataset_retrieval = DatasetRetrieval()
-        return dataset_retrieval.retrieve(
-            tenant_id=tenant_id,
-            model_config=model_config,
-            config=dataset_config,
-            query=query,
-            invoke_from=invoke_from,
-            show_retrieve_source=show_retrieve_source,
-            hit_callback=hit_callback,
-            memory=memory
-        )
\ No newline at end of file
diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py
index decdc0567f..bcb2c318c6 100644
--- a/api/fields/workflow_fields.py
+++ b/api/fields/workflow_fields.py
@@ -1,5 +1,3 @@
-import json
-
 from flask_restful import fields
 
 from fields.member_fields import simple_account_fields
@@ -7,7 +5,8 @@ from libs.helper import TimestampField
 
 workflow_fields = {
     'id': fields.String,
-    'graph': fields.Raw(attribute=lambda x: json.loads(x.graph) if hasattr(x, 'graph') else None),
+    'graph': fields.Nested(simple_account_fields, attribute='graph_dict'),
+    'features': fields.Nested(simple_account_fields, attribute='features_dict'),
     'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'),
     'created_at': TimestampField,
     'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True),
diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py
index 5f7ddc7d68..5ae1e65611 100644
--- a/api/migrations/versions/b289e2408ee2_add_workflow.py
+++ b/api/migrations/versions/b289e2408ee2_add_workflow.py
@@ -97,6 +97,7 @@ def upgrade():
     sa.Column('type', sa.String(length=255), nullable=False),
     sa.Column('version', sa.String(length=255), nullable=False),
     sa.Column('graph', sa.Text(), nullable=True),
+    sa.Column('features', sa.Text(), nullable=True),
     sa.Column('created_by', postgresql.UUID(), nullable=False),
     sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
     sa.Column('updated_by', postgresql.UUID(), nullable=True),
@@ -106,7 +107,7 @@ def upgrade():
     with op.batch_alter_table('workflows', schema=None) as batch_op:
         batch_op.create_index('workflow_version_idx', ['tenant_id', 'app_id', 'version'], unique=False)
 
-    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+    with op.batch_alter_table('apps', schema=None) as batch_op:
         batch_op.add_column(sa.Column('workflow_id', postgresql.UUID(), nullable=True))
 
     with op.batch_alter_table('messages', schema=None) as batch_op:
@@ -120,7 +121,7 @@ def downgrade():
     with op.batch_alter_table('messages', schema=None) as batch_op:
         batch_op.drop_column('workflow_run_id')
 
-    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+    with op.batch_alter_table('apps', schema=None) as batch_op:
         batch_op.drop_column('workflow_id')
 
     with op.batch_alter_table('workflows', schema=None) as batch_op:
diff --git a/api/models/model.py b/api/models/model.py
index 235f77abc3..c6409c61ed 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -63,6 +63,7 @@ class App(db.Model):
     icon = db.Column(db.String(255))
     icon_background = db.Column(db.String(255))
     app_model_config_id = db.Column(UUID, nullable=True)
+    workflow_id = db.Column(UUID, nullable=True)
     status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
     enable_site = db.Column(db.Boolean, nullable=False)
     enable_api = db.Column(db.Boolean, nullable=False)
@@ -85,6 +86,14 @@ class App(db.Model):
             AppModelConfig.id == self.app_model_config_id).first()
         return app_model_config
 
+    @property
+    def workflow(self):
+        if self.workflow_id:
+            from api.models.workflow import Workflow
+            return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first()
+
+        return None
+
     @property
     def api_base_url(self):
         return (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL']
@@ -176,7 +185,6 @@ class AppModelConfig(db.Model):
     dataset_configs = db.Column(db.Text)
     external_data_tools = db.Column(db.Text)
     file_upload = db.Column(db.Text)
-    workflow_id = db.Column(UUID)
 
     @property
     def app(self):
@@ -276,14 +284,6 @@ class AppModelConfig(db.Model):
                 "image": {"enabled": False, "number_limits": 3, "detail": "high",
                           "transfer_methods": ["remote_url", "local_file"]}}
 
-    @property
-    def workflow(self):
-        if self.workflow_id:
-            from api.models.workflow import Workflow
-            return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first()
-
-        return None
-
     def to_dict(self) -> dict:
         return {
             "opening_statement": self.opening_statement,
@@ -343,7 +343,6 @@ class AppModelConfig(db.Model):
             if model_config.get('dataset_configs') else None
         self.file_upload = json.dumps(model_config.get('file_upload')) \
             if model_config.get('file_upload') else None
-        self.workflow_id = model_config.get('workflow_id')
         return self
 
     def copy(self):
@@ -368,8 +367,7 @@ class AppModelConfig(db.Model):
             chat_prompt_config=self.chat_prompt_config,
             completion_prompt_config=self.completion_prompt_config,
             dataset_configs=self.dataset_configs,
-            file_upload=self.file_upload,
-            workflow_id=self.workflow_id
+            file_upload=self.file_upload
         )
 
         return new_app_model_config
diff --git a/api/models/workflow.py b/api/models/workflow.py
index 316d3e623e..c38c1dd610 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -1,3 +1,4 @@
+import json
 from enum import Enum
 from typing import Union
 
@@ -106,6 +107,7 @@ class Workflow(db.Model):
     type = db.Column(db.String(255), nullable=False)
     version = db.Column(db.String(255), nullable=False)
     graph = db.Column(db.Text)
+    features = db.Column(db.Text)
     created_by = db.Column(UUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_by = db.Column(UUID)
@@ -119,6 +121,14 @@ class Workflow(db.Model):
     def updated_by_account(self):
         return Account.query.get(self.updated_by)
 
+    @property
+    def graph_dict(self):
+        return self.graph if not self.graph else json.loads(self.graph)
+
+    @property
+    def features_dict(self):
+        return self.features if not self.features else json.loads(self.features)
+
 
 class WorkflowRunTriggeredFrom(Enum):
     """
diff --git a/api/services/app_service.py b/api/services/app_service.py
index 374727d2d4..7dd5d770ea 100644
--- a/api/services/app_service.py
+++ b/api/services/app_service.py
@@ -64,8 +64,8 @@ class AppService:
         app_template = default_app_templates[app_mode]
 
         # get model config
-        default_model_config = app_template['model_config']
-        if 'model' in default_model_config:
+        default_model_config = app_template.get('model_config')
+        if default_model_config and 'model' in default_model_config:
             # get model provider
             model_manager = ModelManager()
 
@@ -110,12 +110,15 @@ class AppService:
         db.session.add(app)
         db.session.flush()
 
-        app_model_config = AppModelConfig(**default_model_config)
-        app_model_config.app_id = app.id
-        db.session.add(app_model_config)
-        db.session.flush()
+        if default_model_config:
+            app_model_config = AppModelConfig(**default_model_config)
+            app_model_config.app_id = app.id
+            db.session.add(app_model_config)
+            db.session.flush()
 
-        app.app_model_config_id = app_model_config.id
+            app.app_model_config_id = app_model_config.id
+
+        db.session.commit()
 
         app_was_created.send(app, account=account)
 
@@ -135,16 +138,22 @@ class AppService:
 
         app_data = import_data.get('app')
         model_config_data = import_data.get('model_config')
-        workflow_graph = import_data.get('workflow_graph')
+        workflow = import_data.get('workflow')
 
-        if not app_data or not model_config_data:
-            raise ValueError("Missing app or model_config in data argument")
+        if not app_data:
+            raise ValueError("Missing app in data argument")
 
         app_mode = AppMode.value_of(app_data.get('mode'))
         if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
-            if not workflow_graph:
-                raise ValueError("Missing workflow_graph in data argument "
-                                 "when mode is advanced-chat or workflow")
+            if not workflow:
+                raise ValueError("Missing workflow in data argument "
+                                 "when app mode is advanced-chat or workflow")
+        elif app_mode in [AppMode.CHAT, AppMode.AGENT_CHAT]:
+            if not model_config_data:
+                raise ValueError("Missing model_config in data argument "
+                                 "when app mode is chat or agent-chat")
+        else:
+            raise ValueError("Invalid app mode")
 
         app = App(
             tenant_id=tenant_id,
@@ -161,26 +170,32 @@ class AppService:
         db.session.add(app)
         db.session.commit()
 
-        if workflow_graph:
-            # init draft workflow
-            workflow_service = WorkflowService()
-            workflow_service.sync_draft_workflow(app, workflow_graph, account)
-
-        app_model_config = AppModelConfig()
-        app_model_config = app_model_config.from_model_config_dict(model_config_data)
-        app_model_config.app_id = app.id
-
-        db.session.add(app_model_config)
-        db.session.commit()
-
-        app.app_model_config_id = app_model_config.id
-
         app_was_created.send(app, account=account)
 
-        app_model_config_was_updated.send(
-            app,
-            app_model_config=app_model_config
-        )
+        if workflow:
+            # init draft workflow
+            workflow_service = WorkflowService()
+            workflow_service.sync_draft_workflow(
+                app_model=app,
+                graph=workflow.get('graph'),
+                features=workflow.get('features'),
+                account=account
+            )
+
+        if model_config_data:
+            app_model_config = AppModelConfig()
+            app_model_config = app_model_config.from_model_config_dict(model_config_data)
+            app_model_config.app_id = app.id
+
+            db.session.add(app_model_config)
+            db.session.commit()
+
+            app.app_model_config_id = app_model_config.id
+
+            app_model_config_was_updated.send(
+                app,
+                app_model_config=app_model_config
+            )
 
         return app
 
@@ -190,7 +205,7 @@ class AppService:
         :param app: App instance
         :return:
         """
-        app_model_config = app.app_model_config
+        app_mode = AppMode.value_of(app.mode)
 
         export_data = {
             "app": {
@@ -198,16 +213,27 @@ class AppService:
                 "mode": app.mode,
                 "icon": app.icon,
                 "icon_background": app.icon_background
-            },
-            "model_config": app_model_config.to_dict(),
+            }
         }
 
-        if app_model_config.workflow_id:
-            export_data['workflow_graph'] = json.loads(app_model_config.workflow.graph)
+        if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
+            if app.workflow_id:
+                workflow = app.workflow
+                export_data['workflow'] = {
+                    "graph": workflow.graph_dict,
+                    "features": workflow.features_dict
+                }
+            else:
+                workflow_service = WorkflowService()
+                workflow = workflow_service.get_draft_workflow(app)
+                export_data['workflow'] = {
+                    "graph": workflow.graph_dict,
+                    "features": workflow.features_dict
+                }
         else:
-            workflow_service = WorkflowService()
-            workflow = workflow_service.get_draft_workflow(app)
-            export_data['workflow_graph'] = json.loads(workflow.graph)
+            app_model_config = app.app_model_config
+
+            export_data['model_config'] = app_model_config.to_dict()
 
         return yaml.dump(export_data)
diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py
index f384855e7a..6c0182dd9e 100644
--- a/api/services/workflow/workflow_converter.py
+++ b/api/services/workflow/workflow_converter.py
@@ -44,13 +44,10 @@ class WorkflowConverter:
         :param account: Account
         :return: new App instance
         """
-        # get original app config
-        app_model_config = app_model.app_model_config
-
         # convert app model config
         workflow = self.convert_app_model_config_to_workflow(
             app_model=app_model,
-            app_model_config=app_model_config,
+            app_model_config=app_model.app_model_config,
             account_id=account.id
         )
 
@@ -58,8 +55,9 @@ class WorkflowConverter:
         new_app = App()
         new_app.tenant_id = app_model.tenant_id
         new_app.name = app_model.name + '(workflow)'
-        new_app.mode = AppMode.CHAT.value \
+        new_app.mode = AppMode.ADVANCED_CHAT.value \
             if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value
+        new_app.workflow_id = workflow.id
         new_app.icon = app_model.icon
         new_app.icon_background = app_model.icon_background
         new_app.enable_site = app_model.enable_site
@@ -69,28 +67,6 @@ class WorkflowConverter:
         new_app.is_demo = False
         new_app.is_public = app_model.is_public
         db.session.add(new_app)
-        db.session.flush()
-
-        # create new app model config record
-        new_app_model_config = app_model_config.copy()
-        new_app_model_config.id = None
-        new_app_model_config.app_id = new_app.id
-        new_app_model_config.external_data_tools = ''
-        new_app_model_config.model = ''
-        new_app_model_config.user_input_form = ''
-        new_app_model_config.dataset_query_variable = None
-        new_app_model_config.pre_prompt = None
-        new_app_model_config.agent_mode = ''
-        new_app_model_config.prompt_type = 'simple'
-        new_app_model_config.chat_prompt_config = ''
-        new_app_model_config.completion_prompt_config = ''
-        new_app_model_config.dataset_configs = ''
-        new_app_model_config.workflow_id = workflow.id
-
-        db.session.add(new_app_model_config)
-        db.session.flush()
-
-        new_app.app_model_config_id = new_app_model_config.id
         db.session.commit()
 
         app_was_created.send(new_app, account=account)
@@ -110,11 +86,13 @@ class WorkflowConverter:
         # get new app mode
         new_app_mode = self._get_new_app_mode(app_model)
 
+        app_model_config_dict = app_model_config.to_dict()
+
         # convert app model config
         application_manager = AppManager()
         app_orchestration_config_entity = application_manager.convert_from_app_model_config_dict(
             tenant_id=app_model.tenant_id,
-            app_model_config_dict=app_model_config.to_dict(),
+            app_model_config_dict=app_model_config_dict,
             skip_check=True
         )
 
@@ -177,6 +155,25 @@ class WorkflowConverter:
 
         graph = self._append_node(graph, end_node)
 
+        # features
+        if new_app_mode == AppMode.ADVANCED_CHAT:
+            features = {
+                "opening_statement": app_model_config_dict.get("opening_statement"),
+                "suggested_questions": app_model_config_dict.get("suggested_questions"),
+                "suggested_questions_after_answer": app_model_config_dict.get("suggested_questions_after_answer"),
+                "speech_to_text": app_model_config_dict.get("speech_to_text"),
+                "text_to_speech": app_model_config_dict.get("text_to_speech"),
+                "file_upload": app_model_config_dict.get("file_upload"),
+                "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"),
+                "retriever_resource": app_model_config_dict.get("retriever_resource"),
+            }
+        else:
+            features = {
+                "text_to_speech": app_model_config_dict.get("text_to_speech"),
+                "file_upload": app_model_config_dict.get("file_upload"),
+                "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"),
+            }
+
         # create workflow record
         workflow = Workflow(
             tenant_id=app_model.tenant_id,
@@ -184,6 +181,7 @@ class WorkflowConverter:
             type=WorkflowType.from_app_mode(new_app_mode).value,
             version='draft',
             graph=json.dumps(graph),
+            features=json.dumps(features),
             created_by=account_id,
             created_at=app_model_config.created_at
         )
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index 5a9234c70a..006bc44e41 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -33,29 +33,31 @@ class WorkflowService:
         """
         Get published workflow
         """
-        app_model_config = app_model.app_model_config
-
-        if not app_model_config.workflow_id:
+        if not app_model.workflow_id:
             return None
 
         # fetch published workflow by workflow_id
         workflow = db.session.query(Workflow).filter(
             Workflow.tenant_id == app_model.tenant_id,
             Workflow.app_id == app_model.id,
-            Workflow.id == app_model_config.workflow_id
+            Workflow.id == app_model.workflow_id
         ).first()
 
         # return published workflow
         return workflow
 
-    def sync_draft_workflow(self, app_model: App, graph: dict, account: Account) -> Workflow:
+    def sync_draft_workflow(self, app_model: App,
+                            graph: dict,
+                            features: dict,
+                            account: Account) -> Workflow:
         """
         Sync draft workflow
         """
         # fetch draft workflow by app_model
         workflow = self.get_draft_workflow(app_model=app_model)
 
+        # TODO validate features
+
         # create draft workflow if not found
         if not workflow:
             workflow = Workflow(
@@ -64,12 +66,14 @@ class WorkflowService:
                 type=WorkflowType.from_app_mode(app_model.mode).value,
                 version='draft',
                 graph=json.dumps(graph),
+                features=json.dumps(features),
                 created_by=account.id
             )
             db.session.add(workflow)
         # update draft workflow if found
         else:
             workflow.graph = json.dumps(graph)
+            workflow.features = json.dumps(features)
             workflow.updated_by = account.id
             workflow.updated_at = datetime.utcnow()
 
@@ -112,28 +116,7 @@ class WorkflowService:
         db.session.add(workflow)
         db.session.commit()
 
-        app_model_config = app_model.app_model_config
-
-        # create new app model config record
-        new_app_model_config = app_model_config.copy()
-        new_app_model_config.id = None
-        new_app_model_config.app_id = app_model.id
-        new_app_model_config.external_data_tools = ''
-        new_app_model_config.model = ''
-        new_app_model_config.user_input_form = ''
-        new_app_model_config.dataset_query_variable = None
-        new_app_model_config.pre_prompt = None
-        new_app_model_config.agent_mode = ''
-        new_app_model_config.prompt_type = 'simple'
-        new_app_model_config.chat_prompt_config = ''
-        new_app_model_config.completion_prompt_config = ''
-        new_app_model_config.dataset_configs = ''
-        new_app_model_config.workflow_id = workflow.id
-
-        db.session.add(new_app_model_config)
-        db.session.flush()
-
-        app_model.app_model_config_id = new_app_model_config.id
+        app_model.workflow_id = workflow.id
         db.session.commit()
 
         # TODO update app related datasets

From 7bff65304fd4e672e95ccacf700a85c6d9070497 Mon Sep 17 00:00:00 2001
From: takatost
Date: Thu, 29 Feb 2024 22:20:27 +0800
Subject: [PATCH 087/200] add features structure validate

---
 api/controllers/console/app/model_config.py   | 36 +------------------
 .../app/advanced_chat/config_validator.py     |  9 +++--
 api/core/app/validators/moderation.py         | 18 +++++-----
 api/core/app/workflow/config_validator.py     |  9 +++--
 api/services/app_model_config_service.py      |  9 -----
 api/services/workflow_service.py              | 26 ++++++++++++--
 6 files changed, 49 insertions(+), 58 deletions(-)

diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py
index d822f859bc..1301d12da4 100644
--- a/api/controllers/console/app/model_config.py
+++ b/api/controllers/console/app/model_config.py
@@ -2,7 +2,7 @@ import json
 
 from flask import request
 from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_restful import Resource
 
 from controllers.console import api
 from controllers.console.app.wraps import get_app_model
@@ -137,38 +137,4 @@ class ModelConfigResource(Resource):
         return {'result': 'success'}
 
 
-class FeaturesResource(Resource):
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
-    def put(self, app_model):
-        """Get app features"""
-        parser = reqparse.RequestParser()
-        parser.add_argument('features', type=dict, required=True, nullable=False, location='json')
-        args = parser.parse_args()
-
-        model_configuration = AppModelConfigService.validate_features(
-            tenant_id=current_user.current_tenant_id,
-            config=args.get('features'),
-            app_mode=AppMode.value_of(app_model.mode)
-        )
-
-        # update config
-        app_model_config = app_model.app_model_config
-        app_model_config.from_model_config_dict(model_configuration)
-        db.session.commit()
-
-        app_model_config_was_updated.send(
-            app_model,
-            app_model_config=app_model_config
-        )
-
-        return {
-            'result': 'success'
-        }
-
-
 api.add_resource(ModelConfigResource, '/apps/<uuid:app_id>/model-config')
-api.add_resource(FeaturesResource, '/apps/<uuid:app_id>/features')
diff --git a/api/core/app/advanced_chat/config_validator.py b/api/core/app/advanced_chat/config_validator.py
index 39c00c028e..a20198ef4a 100644
--- a/api/core/app/advanced_chat/config_validator.py
+++ b/api/core/app/advanced_chat/config_validator.py
@@ -9,12 +9,13 @@ from core.app.validators.text_to_speech import TextToSpeechValidator
 
 class AdvancedChatAppConfigValidator:
     @classmethod
-    def config_validate(cls, tenant_id: str, config: dict) -> dict:
+    def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
         """
         Validate for advanced chat app model config
 
         :param tenant_id: tenant id
         :param config: app model config args
+        :param only_structure_validate: if True, only structure validation will be performed
         """
         related_config_keys = []
 
@@ -43,7 +44,11 @@ class AdvancedChatAppConfigValidator:
             related_config_keys.extend(current_related_config_keys)
 
         # moderation validation
-        config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config)
+        config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(
+            tenant_id=tenant_id,
+            config=config,
+            only_structure_validate=only_structure_validate
+        )
         related_config_keys.extend(current_related_config_keys)
 
         related_config_keys = list(set(related_config_keys))
diff --git a/api/core/app/validators/moderation.py b/api/core/app/validators/moderation.py
index 4813385588..7a5dff55c9 100644
--- a/api/core/app/validators/moderation.py
+++ b/api/core/app/validators/moderation.py
@@ -7,7 +7,8 @@ logger = logging.getLogger(__name__)
 
 class ModerationValidator:
     @classmethod
-    def validate_and_set_defaults(cls, tenant_id, config: dict) -> tuple[dict, list[str]]:
+    def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \
+            -> tuple[dict, list[str]]:
         if not config.get("sensitive_word_avoidance"):
             config["sensitive_word_avoidance"] = {
                 "enabled": False
@@ -23,13 +24,14 @@ class ModerationValidator:
         if not config["sensitive_word_avoidance"].get("type"):
             raise ValueError("sensitive_word_avoidance.type is required")
 
-        typ = config["sensitive_word_avoidance"]["type"]
-        config = config["sensitive_word_avoidance"]["config"]
+        if not only_structure_validate:
+            typ = config["sensitive_word_avoidance"]["type"]
+            config = config["sensitive_word_avoidance"]["config"]
 
-        ModerationFactory.validate_config(
-            name=typ,
-            tenant_id=tenant_id,
-            config=config
-        )
+            ModerationFactory.validate_config(
+                name=typ,
+                tenant_id=tenant_id,
+                config=config
+            )
 
         return config, ["sensitive_word_avoidance"]
diff --git a/api/core/app/workflow/config_validator.py b/api/core/app/workflow/config_validator.py
index b76eabaeb5..e8381146a7 100644
--- a/api/core/app/workflow/config_validator.py
+++ b/api/core/app/workflow/config_validator.py
@@ -5,12 +5,13 @@ from core.app.validators.text_to_speech import TextToSpeechValidator
 
 class WorkflowAppConfigValidator:
     @classmethod
-    def config_validate(cls, tenant_id: str, config: dict) -> dict:
+    def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
         """
         Validate for workflow app model config
 
         :param tenant_id: tenant id
         :param config: app model config args
+        :param only_structure_validate: only validate the structure of the config
         """
         related_config_keys = []
 
@@ -23,7 +24,11 @@ class WorkflowAppConfigValidator:
         related_config_keys.extend(current_related_config_keys)
 
         # moderation validation
-        config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config)
+        config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(
+            tenant_id=tenant_id,
+            config=config,
+            only_structure_validate=only_structure_validate
+        )
         related_config_keys.extend(current_related_config_keys)
 
         related_config_keys = list(set(related_config_keys))
diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py
index 789d74ed2c..a35b0dd36e 100644
--- a/api/services/app_model_config_service.py
+++ b/api/services/app_model_config_service.py
@@ -18,12 +18,3 @@ class AppModelConfigService:
             return CompletionAppConfigValidator.config_validate(tenant_id, config)
         else:
             raise ValueError(f"Invalid app mode: {app_mode}")
-
-    @classmethod
-    def validate_features(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict:
-        if app_mode == AppMode.ADVANCED_CHAT:
-            return AdvancedChatAppConfigValidator.config_validate(tenant_id, config)
-        elif app_mode == AppMode.WORKFLOW:
-            return WorkflowAppConfigValidator.config_validate(tenant_id, config)
-        else:
-            raise ValueError(f"Invalid app mode: {app_mode}")
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index 006bc44e41..102c861733 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -2,6 +2,8 @@ import json
 from datetime import datetime
 from typing import Optional
 
+from core.app.advanced_chat.config_validator import AdvancedChatAppConfigValidator
+from core.app.workflow.config_validator import WorkflowAppConfigValidator
 from extensions.ext_database import db
 from models.account import Account
 from models.model import App, AppMode
@@ -56,7 +58,11 @@ class WorkflowService:
         # fetch draft workflow by app_model
         workflow = self.get_draft_workflow(app_model=app_model)
 
-        # TODO validate features
+        # validate features structure
+        self.validate_features_structure(
+            app_model=app_model,
+            features=features
+        )
 
         # create draft workflow if not found
         if not workflow:
@@ -100,7 +106,7 @@ class WorkflowService:
         if not draft_workflow:
             raise ValueError('No valid workflow found.')
 
-        # TODO check if the workflow is valid, basic check
+        # TODO check if the workflow structure is valid
 
         # create new workflow
         workflow = Workflow(
@@ -153,3 +159,19 @@ class WorkflowService:
         )
 
         return new_app
+
+    def validate_features_structure(self, app_model: App, features: dict) -> dict:
+        if app_model.mode == AppMode.ADVANCED_CHAT.value:
+            return AdvancedChatAppConfigValidator.config_validate(
+                tenant_id=app_model.tenant_id,
+                config=features,
+                only_structure_validate=True
+            )
+        elif app_model.mode == AppMode.WORKFLOW.value:
+            return WorkflowAppConfigValidator.config_validate(
+                tenant_id=app_model.tenant_id,
+                config=features,
+                only_structure_validate=True
+            )
+        else:
+            raise ValueError(f"Invalid app mode: {app_model.mode}")

From 9651a208a97b4f8da32611106bd47b93eafd30e3 Mon Sep 17 00:00:00 2001
From: takatost
Date: Thu, 29 Feb 2024 22:20:31 +0800
Subject: [PATCH 088/200] lint fix

---
 api/services/app_model_config_service.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py
index a35b0dd36e..f2caeb14ff 100644
--- a/api/services/app_model_config_service.py
+++ b/api/services/app_model_config_service.py
@@ -1,8 +1,6 @@
-from core.app.advanced_chat.config_validator import AdvancedChatAppConfigValidator
 from core.app.agent_chat.config_validator import AgentChatAppConfigValidator
 from core.app.chat.config_validator import ChatAppConfigValidator
 from core.app.completion.config_validator import CompletionAppConfigValidator
-from core.app.workflow.config_validator import WorkflowAppConfigValidator
 from models.model import AppMode

From 43b0440358886d2f94ca5cc714406d45ddc55972 Mon Sep 17 00:00:00 2001
From: takatost
Date: Thu, 29 Feb 2024 22:58:30 +0800
Subject: [PATCH 089/200] support workflow features

---
 api/controllers/console/app/audio.py         |  6 +-
 api/controllers/console/explore/audio.py     | 14 +----
 api/controllers/console/explore/parameter.py | 60 ++++++++++++++------
 api/controllers/service_api/app/app.py       | 51 ++++++++++++-----
 api/controllers/service_api/app/audio.py     | 16 ++----
 api/controllers/web/app.py                   | 49 +++++++++++-----
 api/controllers/web/audio.py                 | 16 +-----
 api/controllers/web/site.py                  |  4 --
 api/core/file/message_file_parser.py         |  6 +-
 api/core/memory/token_buffer_memory.py       |  7 ++-
 api/models/model.py                          |  7 ++-
 api/models/workflow.py                       | 16 ++++++
 api/services/app_service.py                  |  7 ++-
 api/services/audio_service.py                | 49 ++++++++++++++--
 14 files changed, 211 insertions(+), 97 deletions(-)

diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py
index 458fa5098f..c7f3a598ca 100644
--- a/api/controllers/console/app/audio.py
+++ b/api/controllers/console/app/audio.py
@@ -43,7 +43,7 @@ class ChatMessageAudioApi(Resource):
 
         try:
             response = AudioService.transcript_asr(
-                tenant_id=app_model.tenant_id,
+                app_model=app_model,
                 file=file,
                 end_user=None,
             )
@@ -83,9 +83,9 @@ class ChatMessageTextApi(Resource):
     def post(self, app_model):
         try:
             response = AudioService.transcript_tts(
-                tenant_id=app_model.tenant_id,
+                app_model=app_model,
                 text=request.form['text'],
-                voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
+                voice=request.form.get('voice'),
                 streaming=False
             )
diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py
index dc546ce0dd..34ce1ec1ee 100644
--- a/api/controllers/console/explore/audio.py
+++ b/api/controllers/console/explore/audio.py
@@ -32,16 +32,12 @@ from services.errors.audio import (
 class ChatAudioApi(InstalledAppResource):
     def post(self, installed_app):
         app_model = installed_app.app
-        app_model_config: AppModelConfig = app_model.app_model_config
-
-        if not app_model_config.speech_to_text_dict['enabled']:
-            raise AppUnavailableError()
 
         file = request.files['file']
 
         try:
             response = AudioService.transcript_asr(
-                tenant_id=app_model.tenant_id,
+                app_model=app_model,
                 file=file,
                 end_user=None
             )
@@ -76,16 +72,12 @@ class ChatAudioApi(InstalledAppResource):
 class ChatTextApi(InstalledAppResource):
     def post(self, installed_app):
         app_model = installed_app.app
-        app_model_config: AppModelConfig = app_model.app_model_config
-
-        if not app_model_config.text_to_speech_dict['enabled']:
-            raise AppUnavailableError()
 
         try:
             response = AudioService.transcript_tts(
-                tenant_id=app_model.tenant_id,
+                app_model=app_model,
                 text=request.form['text'],
-                voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
+                voice=request.form.get('voice'),
                 streaming=False
             )
             return {'data': response.data.decode('latin1')}
diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py
index c4afb0b923..0239742a4a 100644
--- a/api/controllers/console/explore/parameter.py
+++ b/api/controllers/console/explore/parameter.py
@@ -4,9 +4,10 @@ from flask import current_app
 from flask_restful import fields, marshal_with
 
 from controllers.console import api
+from controllers.console.app.error import AppUnavailableError
 from controllers.console.explore.wraps import InstalledAppResource
 from extensions.ext_database import db
-from models.model import AppModelConfig, InstalledApp
+from models.model import AppModelConfig, InstalledApp, AppMode
 from models.tools import ApiToolProvider
 
 
@@ -45,30 +46,55 @@ class AppParameterApi(InstalledAppResource):
     def get(self, installed_app: InstalledApp):
         """Retrieve app parameters."""
         app_model = installed_app.app
-        app_model_config = app_model.app_model_config
+
+        if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
+            workflow = app_model.workflow
+            if workflow is None:
+                raise AppUnavailableError()
+
+            features_dict = workflow.features_dict
+            user_input_form = workflow.user_input_form
+        else:
+            app_model_config = app_model.app_model_config
+            features_dict = app_model_config.to_dict()
+
+            user_input_form = features_dict.get('user_input_form', [])
 
         return {
-            'opening_statement': app_model_config.opening_statement,
-            'suggested_questions': app_model_config.suggested_questions_list,
-            'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
-            'speech_to_text': app_model_config.speech_to_text_dict,
-            'text_to_speech': app_model_config.text_to_speech_dict,
-            'retriever_resource': app_model_config.retriever_resource_dict,
-            'annotation_reply': app_model_config.annotation_reply_dict,
-            'more_like_this': app_model_config.more_like_this_dict,
-            'user_input_form': app_model_config.user_input_form_list,
-            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
-            'file_upload': app_model_config.file_upload_dict,
+            'opening_statement': features_dict.get('opening_statement'),
+            'suggested_questions': features_dict.get('suggested_questions', []),
+            'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer',
+                                                                  {"enabled": False}),
+            'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}),
+            'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}),
+            'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}),
+            'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}),
+            'more_like_this': features_dict.get('more_like_this', {"enabled": False}),
+            'user_input_form': user_input_form,
+            'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance',
+                                                          {"enabled": False, "type": "", "configs": []}),
+            'file_upload': features_dict.get('file_upload', {"image": {
+                "enabled": False,
+                "number_limits": 3,
+                "detail": "high",
+                "transfer_methods": ["remote_url", "local_file"]
+            }}),
             'system_parameters': {
                 'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
             }
         }
 
+
 class ExploreAppMetaApi(InstalledAppResource):
     def get(self, installed_app: InstalledApp):
         """Get app meta"""
         app_model_config: AppModelConfig = installed_app.app.app_model_config
+        if not app_model_config:
+            return {
+                'tool_icons': {}
+            }
+
         agent_config = app_model_config.agent_mode_dict or {}
         meta = {
             'tool_icons': {}
@@ -77,7 +103,7 @@ class ExploreAppMetaApi(InstalledAppResource):
         # get all tools
tools = agent_config.get('tools', [])
         url_prefix = (current_app.config.get("CONSOLE_API_URL")
-                    + "/console/api/workspaces/current/tool-provider/builtin/")
+                      + "/console/api/workspaces/current/tool-provider/builtin/")
         for tool in tools:
             keys = list(tool.keys())
             if len(keys) >= 4:
@@ -94,12 +120,14 @@ class ExploreAppMetaApi(InstalledAppResource):
                     )
                     meta['tool_icons'][tool_name] = json.loads(provider.icon)
                 except:
-                    meta['tool_icons'][tool_name] =  {
+                    meta['tool_icons'][tool_name] = {
                         "background": "#252525",
                         "content": "\ud83d\ude01"
                     }
         return meta
 
-api.add_resource(AppParameterApi, '/installed-apps/<uuid:installed_app_id>/parameters', endpoint='installed_app_parameters')
+
+api.add_resource(AppParameterApi, '/installed-apps/<uuid:installed_app_id>/parameters',
+                 endpoint='installed_app_parameters')
 api.add_resource(ExploreAppMetaApi, '/installed-apps/<uuid:installed_app_id>/meta', endpoint='installed_app_meta')
diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py
index a3151fc4a2..76708716c2 100644
--- a/api/controllers/service_api/app/app.py
+++ b/api/controllers/service_api/app/app.py
@@ -4,9 +4,10 @@ from flask import current_app
 from flask_restful import fields, marshal_with, Resource
 
 from controllers.service_api import api
+from controllers.service_api.app.error import AppUnavailableError
 from controllers.service_api.wraps import validate_app_token
 from extensions.ext_database import db
-from models.model import App, AppModelConfig
+from models.model import App, AppModelConfig, AppMode
 from models.tools import ApiToolProvider
 
 
@@ -46,31 +47,55 @@ class AppParameterApi(Resource):
     @marshal_with(parameters_fields)
     def get(self, app_model: App):
         """Retrieve app parameters."""
-        app_model_config = app_model.app_model_config
+        if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
+            workflow = app_model.workflow
+            if workflow is None:
+                raise AppUnavailableError()
+
+            features_dict = workflow.features_dict
+            user_input_form = workflow.user_input_form
+        else:
+            app_model_config = app_model.app_model_config
+            features_dict = app_model_config.to_dict()
+
+            user_input_form = features_dict.get('user_input_form', [])
 
         return {
-            'opening_statement': app_model_config.opening_statement,
-            'suggested_questions': app_model_config.suggested_questions_list,
-            'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
-            'speech_to_text': app_model_config.speech_to_text_dict,
-            'text_to_speech': app_model_config.text_to_speech_dict,
-            'retriever_resource': app_model_config.retriever_resource_dict,
-            'annotation_reply': app_model_config.annotation_reply_dict,
-            'more_like_this': app_model_config.more_like_this_dict,
-            'user_input_form': app_model_config.user_input_form_list,
-            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
-            'file_upload': app_model_config.file_upload_dict,
+            'opening_statement': features_dict.get('opening_statement'),
+            'suggested_questions': features_dict.get('suggested_questions', []),
+            'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer',
                                                                   {"enabled": False}),
+            'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}),
+            'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}),
+            'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}),
+            'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}),
+            'more_like_this': features_dict.get('more_like_this', {"enabled": False}),
+            'user_input_form': user_input_form,
+
'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', + {"enabled": False, "type": "", "configs": []}), + 'file_upload': features_dict.get('file_upload', {"image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"] + }}), 'system_parameters': { 'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT') } } + class AppMetaApi(Resource): @validate_app_token def get(self, app_model: App): """Get app meta""" app_model_config: AppModelConfig = app_model.app_model_config + if not app_model_config: + return { + 'tool_icons': {} + } + agent_config = app_model_config.agent_mode_dict or {} meta = { 'tool_icons': {} diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index f6cad501f0..57edab4090 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -33,18 +33,13 @@ from services.errors.audio import ( class AudioApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) def post(self, app_model: App, end_user: EndUser): - app_model_config: AppModelConfig = app_model.app_model_config - - if not app_model_config.speech_to_text_dict['enabled']: - raise AppUnavailableError() - file = request.files['file'] try: response = AudioService.transcript_asr( - tenant_id=app_model.tenant_id, + app_model=app_model, file=file, - end_user=end_user.get_id() + end_user=end_user ) return response @@ -79,15 +74,16 @@ class TextApi(Resource): def post(self, app_model: App, end_user: EndUser): parser = reqparse.RequestParser() parser.add_argument('text', type=str, required=True, nullable=False, location='json') + parser.add_argument('voice', type=str, location='json') parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json') args = parser.parse_args() try: response = AudioService.transcript_tts( - tenant_id=app_model.tenant_id, + app_model=app_model, text=args['text'], - end_user=end_user.get_id(), - voice=app_model.app_model_config.text_to_speech_dict.get('voice'), + end_user=end_user, + voice=args.get('voice'), streaming=args['streaming'] ) diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 25492b1143..07ce098298 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -4,9 +4,10 @@ from flask import current_app from flask_restful import fields, marshal_with from controllers.web import api +from controllers.web.error import AppUnavailableError from controllers.web.wraps import WebApiResource from extensions.ext_database import db -from models.model import App, AppModelConfig +from models.model import App, AppModelConfig, AppMode from models.tools import ApiToolProvider @@ -44,30 +45,52 @@ class AppParameterApi(WebApiResource): @marshal_with(parameters_fields) def get(self, app_model: App, end_user): """Retrieve app parameters.""" - app_model_config = app_model.app_model_config + if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + workflow = app_model.workflow + if workflow is None: + raise AppUnavailableError() + + features_dict = workflow.features_dict + user_input_form = workflow.user_input_form + else: + app_model_config = app_model.app_model_config + features_dict = app_model_config.to_dict() + + user_input_form = features_dict.get('user_input_form', []) return { - 'opening_statement': app_model_config.opening_statement, - 'suggested_questions': 
app_model_config.suggested_questions_list, - 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, - 'speech_to_text': app_model_config.speech_to_text_dict, - 'text_to_speech': app_model_config.text_to_speech_dict, - 'retriever_resource': app_model_config.retriever_resource_dict, - 'annotation_reply': app_model_config.annotation_reply_dict, - 'more_like_this': app_model_config.more_like_this_dict, - 'user_input_form': app_model_config.user_input_form_list, - 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict, - 'file_upload': app_model_config.file_upload_dict, + 'opening_statement': features_dict.get('opening_statement'), + 'suggested_questions': features_dict.get('suggested_questions', []), + 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', + {"enabled": False}), + 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), + 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), + 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), + 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), + 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), + 'user_input_form': user_input_form, + 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', + {"enabled": False, "type": "", "configs": []}), + 'file_upload': features_dict.get('file_upload', {"image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"] + }}), 'system_parameters': { 'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT') } } + class AppMeta(WebApiResource): def get(self, app_model: App, end_user): """Get app meta""" app_model_config: AppModelConfig = app_model.app_model_config + if not app_model_config: + raise AppUnavailableError() + agent_config = app_model_config.agent_mode_dict or {} meta = { 'tool_icons': {} diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 4e677ae288..8b8ab8f090 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -31,16 +31,11 @@ from services.errors.audio import ( class AudioApi(WebApiResource): def post(self, app_model: App, end_user): - app_model_config: AppModelConfig = app_model.app_model_config - - if not app_model_config.speech_to_text_dict['enabled']: - raise AppUnavailableError() - file = request.files['file'] try: response = AudioService.transcript_asr( - tenant_id=app_model.tenant_id, + app_model=app_model, file=file, end_user=end_user ) @@ -74,17 +69,12 @@ class AudioApi(WebApiResource): class TextApi(WebApiResource): def post(self, app_model: App, end_user): - app_model_config: AppModelConfig = app_model.app_model_config - - if not app_model_config.text_to_speech_dict['enabled']: - raise AppUnavailableError() - try: response = AudioService.transcript_tts( - tenant_id=app_model.tenant_id, + app_model=app_model, text=request.form['text'], end_user=end_user.external_user_id, - voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'), + voice=request.form.get('voice'), streaming=False ) diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index d8e2d59707..bf3536d276 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -83,7 +83,3 @@ class AppSiteInfo: 'remove_webapp_brand': remove_webapp_brand, 'replace_webapp_logo': 
replace_webapp_logo,
         }
-
-        if app.enable_site and site.prompt_public:
-            app_model_config = app.app_model_config
-            self.model_config = app_model_config
diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py
index 1b7b8b87da..c132073578 100644
--- a/api/core/file/message_file_parser.py
+++ b/api/core/file/message_file_parser.py
@@ -96,16 +96,16 @@ class MessageFileParser:
         # return all file objs
         return new_files
 
-    def transform_message_files(self, files: list[MessageFile], app_model_config: Optional[AppModelConfig]) -> list[FileObj]:
+    def transform_message_files(self, files: list[MessageFile], file_upload_config: Optional[dict]) -> list[FileObj]:
         """
         transform message files
 
         :param files:
-        :param app_model_config:
+        :param file_upload_config:
         :return:
         """
         # transform files to file objs
-        type_file_objs = self._to_file_objs(files, app_model_config.file_upload_dict)
+        type_file_objs = self._to_file_objs(files, file_upload_config)
 
         # return all file objs
         return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py
index 4d44ac3818..f9200dcc71 100644
--- a/api/core/memory/token_buffer_memory.py
+++ b/api/core/memory/token_buffer_memory.py
@@ -10,7 +10,7 @@ from core.model_runtime.entities.message_entities import (
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers import model_provider_factory
 from extensions.ext_database import db
-from models.model import Conversation, Message
+from models.model import Conversation, Message, AppMode
 
 
 class TokenBufferMemory:
@@ -44,7 +44,10 @@ class TokenBufferMemory:
             files = message.message_files
             if files:
                 file_objs = message_file_parser.transform_message_files(
-                    files, message.app_model_config
+                    files,
+                    message.app_model_config.file_upload_dict
+                    if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                    else message.workflow_run.workflow.features_dict.get('file_upload', {})
                 )
 
                 if not file_objs:
diff --git a/api/models/model.py b/api/models/model.py
index c6409c61ed..e514ea729b 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -82,9 +82,10 @@ class App(db.Model):
 
     @property
     def app_model_config(self) -> Optional['AppModelConfig']:
-        app_model_config = db.session.query(AppModelConfig).filter(
-            AppModelConfig.id == self.app_model_config_id).first()
-        return app_model_config
+        if self.app_model_config_id:
+            return db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first()
+
+        return None
 
     @property
     def workflow(self):
diff --git a/api/models/workflow.py b/api/models/workflow.py
index c38c1dd610..ff4e944e29 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -129,6 +129,23 @@ class Workflow(db.Model):
     def features_dict(self):
         return self.features if not self.features else json.loads(self.features)
 
+    @property
+    def user_input_form(self):
+        # get start node from graph
+        if not self.graph:
+            return []
+
+        graph_dict = self.graph_dict
+        if 'nodes' not in graph_dict:
+            return []
+
+        start_node = next((node for node in graph_dict['nodes'] if node['type'] == 'start'), None)
+        if not start_node:
+            return []
+
+        # get user_input_form from start node
+        return start_node.get('variables', [])
+
 
 class WorkflowRunTriggeredFrom(Enum):
     """
diff --git a/api/services/app_service.py b/api/services/app_service.py
index 7dd5d770ea..e0a7835cb7 100644
--- a/api/services/app_service.py
+++ b/api/services/app_service.py
@@ -175,12 +175,17 @@ class AppService:
         if workflow:
             # init draft workflow
             workflow_service = WorkflowService()
-            workflow_service.sync_draft_workflow(
+            draft_workflow = workflow_service.sync_draft_workflow(
                 app_model=app,
                 graph=workflow.get('graph'),
                 features=workflow.get('features'),
                 account=account
             )
+            workflow_service.publish_workflow(
+                app_model=app,
+                account=account,
+                draft_workflow=draft_workflow
+            )
 
         if model_config_data:
             app_model_config = AppModelConfig()
diff --git a/api/services/audio_service.py b/api/services/audio_service.py
index a9fe65df6f..0123666644 100644
--- a/api/services/audio_service.py
+++ b/api/services/audio_service.py
@@ -5,6 +5,7 @@ from werkzeug.datastructures import FileStorage
 
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
+from models.model import AppModelConfig, App, AppMode
 from services.errors.audio import (
     AudioTooLargeServiceError,
     NoAudioUploadedServiceError,
@@ -20,7 +21,21 @@ ALLOWED_EXTENSIONS = ['mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm', 'amr']
 
 class AudioService:
     @classmethod
-    def transcript_asr(cls, tenant_id: str, file: FileStorage, end_user: Optional[str] = None):
+    def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None):
+        if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
+            workflow = app_model.workflow
+            if workflow is None:
+                raise ValueError("Speech to text is not enabled")
+
+            features_dict = workflow.features_dict
+            if 'speech_to_text' not in features_dict or not features_dict['speech_to_text'].get('enabled'):
+                raise ValueError("Speech to text is not enabled")
+        else:
+            app_model_config: AppModelConfig = app_model.app_model_config
+
+            if not app_model_config.speech_to_text_dict['enabled']:
+                raise ValueError("Speech to text is not enabled")
+
         if file is None:
             raise NoAudioUploadedServiceError()
 
@@ -37,7 +52,7 @@
         model_manager = ModelManager()
         model_instance = model_manager.get_default_model_instance(
-            tenant_id=tenant_id,
+            tenant_id=app_model.tenant_id,
             model_type=ModelType.SPEECH2TEXT
         )
         if model_instance is None:
@@ -49,17 +64,41 @@
         return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)}
 
     @classmethod
-    def transcript_tts(cls, tenant_id: str, text: str, voice: str, streaming: bool, end_user: Optional[str] = None):
+    def transcript_tts(cls, app_model: App, text: str, streaming: bool, voice: Optional[str] = None, end_user: Optional[str] = None):
+        if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
+            workflow = app_model.workflow
+            if workflow is None:
+                raise ValueError("TTS is not enabled")
+
+            features_dict = workflow.features_dict
+            if 'text_to_speech' not in features_dict or not features_dict['text_to_speech'].get('enabled'):
+                raise ValueError("TTS is not enabled")
+
+            voice = voice or features_dict['text_to_speech'].get('voice')
+        else:
+            text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
+
+            if not text_to_speech_dict.get('enabled'):
+                raise ValueError("TTS is not enabled")
+
+            voice = voice or text_to_speech_dict.get('voice')
+
         model_manager = ModelManager()
         model_instance = model_manager.get_default_model_instance(
-            tenant_id=tenant_id,
+            tenant_id=app_model.tenant_id,
             model_type=ModelType.TTS
         )
         if model_instance is None:
             raise ProviderNotSupportTextToSpeechServiceError()
 
         try:
-            return model_instance.invoke_tts(content_text=text.strip(), user=end_user, streaming=streaming, tenant_id=tenant_id,
voice=voice) + return model_instance.invoke_tts( + content_text=text.strip(), + user=end_user, + streaming=streaming, + tenant_id=app_model.tenant_id, + voice=voice + ) except Exception as e: raise e From 15c7e0ec2f2778f92c352b1373d0273afe6689f8 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 29 Feb 2024 22:58:33 +0800 Subject: [PATCH 090/200] lint fix --- api/controllers/console/explore/audio.py | 1 - api/controllers/console/explore/parameter.py | 2 +- api/controllers/service_api/app/audio.py | 2 +- api/controllers/web/audio.py | 2 +- api/core/memory/token_buffer_memory.py | 2 +- api/services/audio_service.py | 2 +- 6 files changed, 5 insertions(+), 6 deletions(-) diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 34ce1ec1ee..f03663f1a2 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -19,7 +19,6 @@ from controllers.console.app.error import ( from controllers.console.explore.wraps import InstalledAppResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from models.model import AppModelConfig from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 0239742a4a..9c0fca57f2 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -7,7 +7,7 @@ from controllers.console import api from controllers.console.app.error import AppUnavailableError from controllers.console.explore.wraps import InstalledAppResource from extensions.ext_database import db -from models.model import AppModelConfig, InstalledApp, AppMode +from models.model import AppMode, AppModelConfig, InstalledApp from models.tools import ApiToolProvider diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 57edab4090..15c0a153b8 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -20,7 +20,7 @@ from controllers.service_api.app.error import ( from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from models.model import App, AppModelConfig, EndUser +from models.model import App, EndUser from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 8b8ab8f090..e0074c452f 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -19,7 +19,7 @@ from controllers.web.error import ( from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from models.model import App, AppModelConfig +from models.model import App from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index f9200dcc71..00813faef7 100644 --- a/api/core/memory/token_buffer_memory.py +++ 
b/api/core/memory/token_buffer_memory.py @@ -10,7 +10,7 @@ from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers import model_provider_factory from extensions.ext_database import db -from models.model import Conversation, Message, AppMode +from models.model import AppMode, Conversation, Message class TokenBufferMemory: diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 0123666644..7a658487f8 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -5,7 +5,7 @@ from werkzeug.datastructures import FileStorage from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from models.model import AppModelConfig, App, AppMode +from models.model import App, AppMode, AppModelConfig from services.errors.audio import ( AudioTooLargeServiceError, NoAudioUploadedServiceError, From 3f5d1a79c664109650b435bfeee9151afff1a798 Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 2 Mar 2024 02:40:18 +0800 Subject: [PATCH 091/200] refactor apps --- api/controllers/console/app/audio.py | 2 +- api/controllers/console/app/completion.py | 6 +- api/controllers/console/app/conversation.py | 8 +- api/controllers/console/app/message.py | 4 +- api/controllers/console/app/statistic.py | 2 +- api/controllers/console/explore/completion.py | 2 +- api/controllers/console/explore/message.py | 2 +- api/controllers/service_api/app/completion.py | 2 +- api/controllers/web/completion.py | 2 +- api/controllers/web/message.py | 2 +- api/core/agent/base_agent_runner.py | 47 +- api/core/agent/cot_agent_runner.py | 33 +- api/core/agent/entities.py | 61 +++ api/core/agent/fc_agent_runner.py | 14 +- .../app/advanced_chat/config_validator.py | 59 --- .../{advanced_chat => app_config}/__init__.py | 0 .../app/app_config/base_app_config_manager.py | 73 +++ .../common}/__init__.py | 0 .../sensitive_word_avoidance}/__init__.py | 0 .../sensitive_word_avoidance/manager.py} | 19 +- .../easy_ui_based_app}/__init__.py | 0 .../easy_ui_based_app/agent}/__init__.py | 0 .../easy_ui_based_app/agent/manager.py | 79 ++++ .../easy_ui_based_app/dataset}/__init__.py | 0 .../easy_ui_based_app/dataset/manager.py} | 87 +++- .../model_config/__init__.py | 0 .../model_config/converter.py | 104 +++++ .../model_config/manager.py} | 36 +- .../prompt_template/__init__.py | 0 .../prompt_template/manager.py} | 59 ++- .../easy_ui_based_app/variables/__init__.py | 0 .../easy_ui_based_app/variables/manager.py | 184 ++++++++ .../app_config/entities.py} | 167 ++----- api/core/app/app_config/features/__init__.py | 0 .../features/file_upload/__init__.py | 0 .../features/file_upload/manager.py} | 26 +- .../features/more_like_this/__init__.py | 0 .../features/more_like_this/manager.py} | 15 +- .../features/opening_statement/__init__.py | 0 .../features/opening_statement/manager.py} | 18 +- .../features/retrieval_resource/__init__.py | 0 .../features/retrieval_resource/manager.py} | 10 +- .../features/speech_to_text/__init__.py | 0 .../features/speech_to_text/manager.py} | 15 +- .../__init__.py | 0 .../manager.py} | 18 +- .../features/text_to_speech/__init__.py | 0 .../features/text_to_speech/manager.py} | 22 +- .../workflow_ui_based_app/__init__.py | 0 .../variables/__init__.py | 0 .../variables/manager.py | 22 + api/core/app/app_manager.py | 198 +++++--- .../app/app_orchestration_config_converter.py | 421 ------------------ api/core/app/app_queue_manager.py | 4 +- 
api/core/app/apps/__init__.py | 0 api/core/app/apps/advanced_chat/__init__.py | 0 .../apps/advanced_chat/app_config_manager.py | 94 ++++ api/core/app/apps/agent_chat/__init__.py | 0 .../agent_chat/app_config_manager.py} | 114 +++-- .../app/{ => apps}/agent_chat/app_runner.py | 69 +-- api/core/app/{ => apps}/base_app_runner.py | 35 +- api/core/app/apps/chat/__init__.py | 0 api/core/app/apps/chat/app_config_manager.py | 135 ++++++ api/core/app/{ => apps}/chat/app_runner.py | 61 +-- api/core/app/apps/completion/__init__.py | 0 .../app/apps/completion/app_config_manager.py | 118 +++++ .../app/{ => apps}/completion/app_runner.py | 53 +-- api/core/app/apps/workflow/__init__.py | 0 .../app/apps/workflow/app_config_manager.py | 71 +++ api/core/app/chat/config_validator.py | 82 ---- api/core/app/completion/config_validator.py | 67 --- api/core/app/entities/__init__.py | 0 api/core/app/entities/app_invoke_entities.py | 111 +++++ api/core/{ => app}/entities/queue_entities.py | 0 .../annotation_reply/annotation_reply.py | 2 +- .../hosting_moderation/hosting_moderation.py | 7 +- api/core/app/generate_task_pipeline.py | 22 +- .../app/validators/external_data_fetch.py | 39 -- api/core/app/validators/user_input_form.py | 61 --- api/core/app/workflow/config_validator.py | 39 -- .../agent_loop_gather_callback_handler.py | 262 ----------- .../callback_handler/entity/agent_loop.py | 23 - .../index_tool_callback_handler.py | 2 +- .../external_data_tool/external_data_fetch.py | 2 +- api/core/file/file_obj.py | 5 +- api/core/file/message_file_parser.py | 35 +- api/core/helper/moderation.py | 4 +- api/core/memory/token_buffer_memory.py | 20 +- api/core/moderation/input_moderation.py | 10 +- api/core/prompt/advanced_prompt_transform.py | 15 +- api/core/prompt/prompt_transform.py | 6 +- api/core/prompt/simple_prompt_transform.py | 14 +- .../rag/retrieval/agent/agent_llm_callback.py | 101 ----- api/core/rag/retrieval/agent/llm_chain.py | 7 +- .../agent/multi_dataset_router_agent.py | 6 +- .../structed_multi_dataset_router_agent.py | 4 +- .../retrieval/agent_based_dataset_executor.py | 8 +- api/core/rag/retrieval/dataset_retrieval.py | 5 +- api/core/tools/tool/dataset_retriever_tool.py | 3 +- .../deduct_quota_when_messaeg_created.py | 8 +- ...vider_last_used_at_when_messaeg_created.py | 8 +- api/models/model.py | 12 + api/models/workflow.py | 2 +- api/services/app_model_config_service.py | 12 +- api/services/completion_service.py | 147 ++---- api/services/workflow/workflow_converter.py | 46 +- api/services/workflow_service.py | 8 +- .../prompt/test_advanced_prompt_transform.py | 10 +- .../core/prompt/test_prompt_transform.py | 2 +- .../prompt/test_simple_prompt_transform.py | 6 +- .../workflow/test_workflow_converter.py | 2 +- 111 files changed, 1979 insertions(+), 1819 deletions(-) create mode 100644 api/core/agent/entities.py delete mode 100644 api/core/app/advanced_chat/config_validator.py rename api/core/app/{advanced_chat => app_config}/__init__.py (100%) create mode 100644 api/core/app/app_config/base_app_config_manager.py rename api/core/app/{agent_chat => app_config/common}/__init__.py (100%) rename api/core/app/{chat => app_config/common/sensitive_word_avoidance}/__init__.py (100%) rename api/core/app/{validators/moderation.py => app_config/common/sensitive_word_avoidance/manager.py} (64%) rename api/core/app/{completion => app_config/easy_ui_based_app}/__init__.py (100%) rename api/core/app/{validators => app_config/easy_ui_based_app/agent}/__init__.py (100%) create mode 100644 
api/core/app/app_config/easy_ui_based_app/agent/manager.py rename api/core/app/{workflow => app_config/easy_ui_based_app/dataset}/__init__.py (100%) rename api/core/app/{validators/dataset_retrieval.py => app_config/easy_ui_based_app/dataset/manager.py} (63%) create mode 100644 api/core/app/app_config/easy_ui_based_app/model_config/__init__.py create mode 100644 api/core/app/app_config/easy_ui_based_app/model_config/converter.py rename api/core/app/{validators/model_validator.py => app_config/easy_ui_based_app/model_config/manager.py} (73%) create mode 100644 api/core/app/app_config/easy_ui_based_app/prompt_template/__init__.py rename api/core/app/{validators/prompt.py => app_config/easy_ui_based_app/prompt_template/manager.py} (58%) create mode 100644 api/core/app/app_config/easy_ui_based_app/variables/__init__.py create mode 100644 api/core/app/app_config/easy_ui_based_app/variables/manager.py rename api/core/{entities/application_entities.py => app/app_config/entities.py} (61%) create mode 100644 api/core/app/app_config/features/__init__.py create mode 100644 api/core/app/app_config/features/file_upload/__init__.py rename api/core/app/{validators/file_upload.py => app_config/features/file_upload/manager.py} (59%) create mode 100644 api/core/app/app_config/features/more_like_this/__init__.py rename api/core/app/{validators/more_like_this.py => app_config/features/more_like_this/manager.py} (63%) create mode 100644 api/core/app/app_config/features/opening_statement/__init__.py rename api/core/app/{validators/opening_statement.py => app_config/features/opening_statement/manager.py} (66%) create mode 100644 api/core/app/app_config/features/retrieval_resource/__init__.py rename api/core/app/{validators/retriever_resource.py => app_config/features/retrieval_resource/manager.py} (68%) create mode 100644 api/core/app/app_config/features/speech_to_text/__init__.py rename api/core/app/{validators/speech_to_text.py => app_config/features/speech_to_text/manager.py} (63%) create mode 100644 api/core/app/app_config/features/suggested_questions_after_answer/__init__.py rename api/core/app/{validators/suggested_questions.py => app_config/features/suggested_questions_after_answer/manager.py} (57%) create mode 100644 api/core/app/app_config/features/text_to_speech/__init__.py rename api/core/app/{validators/text_to_speech.py => app_config/features/text_to_speech/manager.py} (56%) create mode 100644 api/core/app/app_config/workflow_ui_based_app/__init__.py create mode 100644 api/core/app/app_config/workflow_ui_based_app/variables/__init__.py create mode 100644 api/core/app/app_config/workflow_ui_based_app/variables/manager.py delete mode 100644 api/core/app/app_orchestration_config_converter.py create mode 100644 api/core/app/apps/__init__.py create mode 100644 api/core/app/apps/advanced_chat/__init__.py create mode 100644 api/core/app/apps/advanced_chat/app_config_manager.py create mode 100644 api/core/app/apps/agent_chat/__init__.py rename api/core/app/{agent_chat/config_validator.py => apps/agent_chat/app_config_manager.py} (51%) rename api/core/app/{ => apps}/agent_chat/app_runner.py (83%) rename api/core/app/{ => apps}/base_app_runner.py (93%) create mode 100644 api/core/app/apps/chat/__init__.py create mode 100644 api/core/app/apps/chat/app_config_manager.py rename api/core/app/{ => apps}/chat/app_runner.py (76%) create mode 100644 api/core/app/apps/completion/__init__.py create mode 100644 api/core/app/apps/completion/app_config_manager.py rename api/core/app/{ => apps}/completion/app_runner.py 
(74%) create mode 100644 api/core/app/apps/workflow/__init__.py create mode 100644 api/core/app/apps/workflow/app_config_manager.py delete mode 100644 api/core/app/chat/config_validator.py delete mode 100644 api/core/app/completion/config_validator.py create mode 100644 api/core/app/entities/__init__.py create mode 100644 api/core/app/entities/app_invoke_entities.py rename api/core/{ => app}/entities/queue_entities.py (100%) delete mode 100644 api/core/app/validators/external_data_fetch.py delete mode 100644 api/core/app/validators/user_input_form.py delete mode 100644 api/core/app/workflow/config_validator.py delete mode 100644 api/core/callback_handler/agent_loop_gather_callback_handler.py delete mode 100644 api/core/callback_handler/entity/agent_loop.py delete mode 100644 api/core/rag/retrieval/agent/agent_llm_callback.py diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index c7f3a598ca..4de4a6f3fe 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -37,7 +37,7 @@ class ChatMessageAudioApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.CHAT) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) def post(self, app_model): file = request.files['file'] diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 0632c0439b..ed1522c0cd 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -22,7 +22,7 @@ from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.app.app_queue_manager import AppQueueManager -from core.entities.application_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value @@ -103,7 +103,7 @@ class ChatMessageApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.CHAT) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) def post(self, app_model): parser = reqparse.RequestParser() parser.add_argument('inputs', type=dict, required=True, location='json') @@ -168,7 +168,7 @@ class ChatMessageStopApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.CHAT) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) def post(self, app_model, task_id): account = flask_login.current_user diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index b808d62eb0..33711076f8 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -112,7 +112,7 @@ class CompletionConversationDetailApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.CHAT) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) def delete(self, app_model, conversation_id): conversation_id = str(conversation_id) @@ -133,7 +133,7 @@ class ChatConversationApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.CHAT) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) 
@marshal_with(conversation_with_summary_pagination_fields) def get(self, app_model): parser = reqparse.RequestParser() @@ -218,7 +218,7 @@ class ChatConversationDetailApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.CHAT) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) @marshal_with(conversation_detail_fields) def get(self, app_model, conversation_id): conversation_id = str(conversation_id) @@ -227,7 +227,7 @@ class ChatConversationDetailApi(Resource): @setup_required @login_required - @get_app_model(mode=AppMode.CHAT) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) @account_initialization_required def delete(self, app_model, conversation_id): conversation_id = str(conversation_id) diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index c384e878aa..111ec7d787 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -42,7 +42,7 @@ class ChatMessageListApi(Resource): @setup_required @login_required - @get_app_model(mode=AppMode.CHAT) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) @account_initialization_required @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model): @@ -194,7 +194,7 @@ class MessageSuggestedQuestionApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.CHAT) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) def get(self, app_model, message_id): message_id = str(message_id) diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index e3a5112200..51fe53c0ec 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -203,7 +203,7 @@ class AverageSessionInteractionStatistic(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=AppMode.CHAT) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) def get(self, app_model): account = current_user diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 22ea4bbac2..dd531974fa 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -22,7 +22,7 @@ from controllers.console.app.error import ( from controllers.console.explore.error import NotChatAppError, NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource from core.app.app_queue_manager import AppQueueManager -from core.entities.application_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 47af28425f..fdb0eae24f 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -24,7 +24,7 @@ from controllers.console.explore.error import ( NotCompletionAppError, ) from controllers.console.explore.wraps import InstalledAppResource -from core.entities.application_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from 
core.model_runtime.errors.invoke import InvokeError from fields.message_fields import message_infinite_scroll_pagination_fields diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index fd4ce831b3..5c488093fa 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -20,7 +20,7 @@ from controllers.service_api.app.error import ( ) from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.app.app_queue_manager import AppQueueManager -from core.entities.application_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index fd94ec7646..785e2b8d6b 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -21,7 +21,7 @@ from controllers.web.error import ( ) from controllers.web.wraps import WebApiResource from core.app.app_queue_manager import AppQueueManager -from core.entities.application_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index e03bdd63bb..1acb92dbf1 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -21,7 +21,7 @@ from controllers.web.error import ( ProviderQuotaExceededError, ) from controllers.web.wraps import WebApiResource -from core.entities.application_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import message_file_fields diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 1474c6a475..529240aecb 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -5,17 +5,15 @@ from datetime import datetime from mimetypes import guess_extension from typing import Optional, Union, cast +from core.agent.entities import AgentEntity, AgentToolEntity from core.app.app_queue_manager import AppQueueManager -from core.app.base_app_runner import AppRunner +from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig +from core.app.apps.base_app_runner import AppRunner from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.entities.application_entities import ( - AgentEntity, - AgentToolEntity, - ApplicationGenerateEntity, - AppOrchestrationConfigEntity, - InvokeFrom, - ModelConfigEntity, +from core.app.entities.app_invoke_entities import ( + EasyUIBasedAppGenerateEntity, + InvokeFrom, EasyUIBasedModelConfigEntity, ) from core.file.message_file_parser import FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory @@ -50,9 +48,9 @@ logger = logging.getLogger(__name__) 
class BaseAgentRunner(AppRunner): def __init__(self, tenant_id: str, - application_generate_entity: ApplicationGenerateEntity, - app_orchestration_config: AppOrchestrationConfigEntity, - model_config: ModelConfigEntity, + application_generate_entity: EasyUIBasedAppGenerateEntity, + app_config: AgentChatAppConfig, + model_config: EasyUIBasedModelConfigEntity, config: AgentEntity, queue_manager: AppQueueManager, message: Message, @@ -66,7 +64,7 @@ class BaseAgentRunner(AppRunner): """ Agent runner :param tenant_id: tenant id - :param app_orchestration_config: app orchestration config + :param app_config: app generate entity :param model_config: model config :param config: dataset config :param queue_manager: queue manager @@ -78,7 +76,7 @@ class BaseAgentRunner(AppRunner): """ self.tenant_id = tenant_id self.application_generate_entity = application_generate_entity - self.app_orchestration_config = app_orchestration_config + self.app_config = app_config self.model_config = model_config self.config = config self.queue_manager = queue_manager @@ -97,16 +95,16 @@ class BaseAgentRunner(AppRunner): # init dataset tools hit_callback = DatasetIndexToolCallbackHandler( queue_manager=queue_manager, - app_id=self.application_generate_entity.app_id, + app_id=self.app_config.app_id, message_id=message.id, user_id=user_id, invoke_from=self.application_generate_entity.invoke_from, ) self.dataset_tools = DatasetRetrieverTool.get_dataset_tools( tenant_id=tenant_id, - dataset_ids=app_orchestration_config.dataset.dataset_ids if app_orchestration_config.dataset else [], - retrieve_config=app_orchestration_config.dataset.retrieve_config if app_orchestration_config.dataset else None, - return_resource=app_orchestration_config.show_retrieve_source, + dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [], + retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None, + return_resource=app_config.additional_features.show_retrieve_source, invoke_from=application_generate_entity.invoke_from, hit_callback=hit_callback ) @@ -124,14 +122,15 @@ class BaseAgentRunner(AppRunner): else: self.stream_tool_call = False - def _repack_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity: + def _repack_app_generate_entity(self, app_generate_entity: EasyUIBasedAppGenerateEntity) \ + -> EasyUIBasedAppGenerateEntity: """ - Repack app orchestration config + Repack app generate entity """ - if app_orchestration_config.prompt_template.simple_prompt_template is None: - app_orchestration_config.prompt_template.simple_prompt_template = '' + if app_generate_entity.app_config.prompt_template.simple_prompt_template is None: + app_generate_entity.app_config.prompt_template.simple_prompt_template = '' - return app_orchestration_config + return app_generate_entity def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str: """ @@ -351,7 +350,7 @@ class BaseAgentRunner(AppRunner): )) db.session.close() - + return result def create_agent_thought(self, message_id: str, message: str, @@ -462,7 +461,7 @@ class BaseAgentRunner(AppRunner): db.session.commit() db.session.close() - + def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]: """ Transform tool message into agent thought diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 5650113f47..5b345f4da0 100644 --- a/api/core/agent/cot_agent_runner.py +++ 
b/api/core/agent/cot_agent_runner.py @@ -5,7 +5,7 @@ from typing import Literal, Union from core.agent.base_agent_runner import BaseAgentRunner from core.app.app_queue_manager import PublishFrom -from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit +from core.agent.entities import AgentPromptEntity, AgentScratchpadUnit from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -27,7 +27,7 @@ from core.tools.errors import ( from models.model import Conversation, Message -class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): +class CotAgentRunner(BaseAgentRunner): _is_first_iteration = True _ignore_observation_providers = ['wenxin'] @@ -39,30 +39,33 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): """ Run Cot agent application """ - app_orchestration_config = self.app_orchestration_config - self._repack_app_orchestration_config(app_orchestration_config) + app_generate_entity = self.application_generate_entity + self._repack_app_generate_entity(app_generate_entity) agent_scratchpad: list[AgentScratchpadUnit] = [] self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages) - if 'Observation' not in app_orchestration_config.model_config.stop: - if app_orchestration_config.model_config.provider not in self._ignore_observation_providers: - app_orchestration_config.model_config.stop.append('Observation') + # check model mode + if 'Observation' not in app_generate_entity.model_config.stop: + if app_generate_entity.model_config.provider not in self._ignore_observation_providers: + app_generate_entity.model_config.stop.append('Observation') + + app_config = self.app_config # override inputs inputs = inputs or {} - instruction = self.app_orchestration_config.prompt_template.simple_prompt_template + instruction = app_config.prompt_template.simple_prompt_template instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) iteration_step = 1 - max_iteration_steps = min(self.app_orchestration_config.agent.max_iteration, 5) + 1 + max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 prompt_messages = self.history_prompt_messages # convert tools into ModelRuntime Tool format prompt_messages_tools: list[PromptMessageTool] = [] tool_instances = {} - for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []: + for tool in app_config.agent.tools if app_config.agent else []: try: prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) except Exception: @@ -122,11 +125,11 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): # update prompt messages prompt_messages = self._organize_cot_prompt_messages( - mode=app_orchestration_config.model_config.mode, + mode=app_generate_entity.model_config.mode, prompt_messages=prompt_messages, tools=prompt_messages_tools, agent_scratchpad=agent_scratchpad, - agent_prompt_message=app_orchestration_config.agent.prompt, + agent_prompt_message=app_config.agent.prompt, instruction=instruction, input=query ) @@ -136,9 +139,9 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): # invoke model chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm( prompt_messages=prompt_messages, - model_parameters=app_orchestration_config.model_config.parameters, + model_parameters=app_generate_entity.model_config.parameters, 
tools=[], - stop=app_orchestration_config.model_config.stop, + stop=app_generate_entity.model_config.stop, stream=True, user=self.user_id, callbacks=[], @@ -550,7 +553,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): """ convert agent scratchpad list to str """ - next_iteration = self.app_orchestration_config.agent.prompt.next_iteration + next_iteration = self.app_config.agent.prompt.next_iteration result = '' for scratchpad in agent_scratchpad: diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py new file mode 100644 index 0000000000..0fbfdc2636 --- /dev/null +++ b/api/core/agent/entities.py @@ -0,0 +1,61 @@ +from enum import Enum +from typing import Literal, Any, Union, Optional + +from pydantic import BaseModel + + +class AgentToolEntity(BaseModel): + """ + Agent Tool Entity. + """ + provider_type: Literal["builtin", "api"] + provider_id: str + tool_name: str + tool_parameters: dict[str, Any] = {} + + +class AgentPromptEntity(BaseModel): + """ + Agent Prompt Entity. + """ + first_prompt: str + next_iteration: str + + +class AgentScratchpadUnit(BaseModel): + """ + Agent First Prompt Entity. + """ + + class Action(BaseModel): + """ + Action Entity. + """ + action_name: str + action_input: Union[dict, str] + + agent_response: Optional[str] = None + thought: Optional[str] = None + action_str: Optional[str] = None + observation: Optional[str] = None + action: Optional[Action] = None + + +class AgentEntity(BaseModel): + """ + Agent Entity. + """ + + class Strategy(Enum): + """ + Agent Strategy. + """ + CHAIN_OF_THOUGHT = 'chain-of-thought' + FUNCTION_CALLING = 'function-calling' + + provider: str + model: str + strategy: Strategy + prompt: Optional[AgentPromptEntity] = None + tools: list[AgentToolEntity] = None + max_iteration: int = 5 diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 9b238bf232..30e5cdd694 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -34,9 +34,11 @@ class FunctionCallAgentRunner(BaseAgentRunner): """ Run FunctionCall agent application """ - app_orchestration_config = self.app_orchestration_config + app_generate_entity = self.application_generate_entity - prompt_template = self.app_orchestration_config.prompt_template.simple_prompt_template or '' + app_config = self.app_config + + prompt_template = app_config.prompt_template.simple_prompt_template or '' prompt_messages = self.history_prompt_messages prompt_messages = self.organize_prompt_messages( prompt_template=prompt_template, @@ -47,7 +49,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): # convert tools into ModelRuntime Tool format prompt_messages_tools: list[PromptMessageTool] = [] tool_instances = {} - for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []: + for tool in app_config.agent.tools if app_config.agent else []: try: prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) except Exception: @@ -67,7 +69,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): tool_instances[dataset_tool.identity.name] = dataset_tool iteration_step = 1 - max_iteration_steps = min(app_orchestration_config.agent.max_iteration, 5) + 1 + max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 # continue to run until there is not any tool call function_call_state = True @@ -110,9 +112,9 @@ class FunctionCallAgentRunner(BaseAgentRunner): # invoke model chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = 
model_instance.invoke_llm( prompt_messages=prompt_messages, - model_parameters=app_orchestration_config.model_config.parameters, + model_parameters=app_generate_entity.model_config.parameters, tools=prompt_messages_tools, - stop=app_orchestration_config.model_config.stop, + stop=app_generate_entity.model_config.stop, stream=self.stream_tool_call, user=self.user_id, callbacks=[], diff --git a/api/core/app/advanced_chat/config_validator.py b/api/core/app/advanced_chat/config_validator.py deleted file mode 100644 index a20198ef4a..0000000000 --- a/api/core/app/advanced_chat/config_validator.py +++ /dev/null @@ -1,59 +0,0 @@ -from core.app.validators.file_upload import FileUploadValidator -from core.app.validators.moderation import ModerationValidator -from core.app.validators.opening_statement import OpeningStatementValidator -from core.app.validators.retriever_resource import RetrieverResourceValidator -from core.app.validators.speech_to_text import SpeechToTextValidator -from core.app.validators.suggested_questions import SuggestedQuestionsValidator -from core.app.validators.text_to_speech import TextToSpeechValidator - - -class AdvancedChatAppConfigValidator: - @classmethod - def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: - """ - Validate for advanced chat app model config - - :param tenant_id: tenant id - :param config: app model config args - :param only_structure_validate: if True, only structure validation will be performed - """ - related_config_keys = [] - - # file upload validation - config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # opening_statement - config, current_related_config_keys = OpeningStatementValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # suggested_questions_after_answer - config, current_related_config_keys = SuggestedQuestionsValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # speech_to_text - config, current_related_config_keys = SpeechToTextValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # text_to_speech - config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # return retriever resource - config, current_related_config_keys = RetrieverResourceValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # moderation validation - config, current_related_config_keys = ModerationValidator.validate_and_set_defaults( - tenant_id=tenant_id, - config=config, - only_structure_validate=only_structure_validate - ) - related_config_keys.extend(current_related_config_keys) - - related_config_keys = list(set(related_config_keys)) - - # Filter out extra parameters - filtered_config = {key: config.get(key) for key in related_config_keys} - - return filtered_config diff --git a/api/core/app/advanced_chat/__init__.py b/api/core/app/app_config/__init__.py similarity index 100% rename from api/core/app/advanced_chat/__init__.py rename to api/core/app/app_config/__init__.py diff --git a/api/core/app/app_config/base_app_config_manager.py b/api/core/app/app_config/base_app_config_manager.py new file mode 100644 index 0000000000..b3c773203d --- /dev/null +++ b/api/core/app/app_config/base_app_config_manager.py @@ 
-0,0 +1,73 @@ +from typing import Union, Optional + +from core.app.app_config.entities import AppAdditionalFeatures, EasyUIBasedAppModelConfigFrom +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager +from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager +from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager +from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import \ + SuggestedQuestionsAfterAnswerConfigManager +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager +from models.model import AppModelConfig + + +class BaseAppConfigManager: + + @classmethod + def convert_to_config_dict(cls, config_from: EasyUIBasedAppModelConfigFrom, + app_model_config: Union[AppModelConfig, dict], + config_dict: Optional[dict] = None) -> dict: + """ + Convert app model config to config dict + :param config_from: app model config from + :param app_model_config: app model config + :param config_dict: app model config dict + :return: + """ + if config_from != EasyUIBasedAppModelConfigFrom.ARGS: + app_model_config_dict = app_model_config.to_dict() + config_dict = app_model_config_dict.copy() + + return config_dict + + @classmethod + def convert_features(cls, config_dict: dict) -> AppAdditionalFeatures: + """ + Convert app config to app model config + + :param config_dict: app config + """ + config_dict = config_dict.copy() + + additional_features = AppAdditionalFeatures() + additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert( + config=config_dict + ) + + additional_features.file_upload = FileUploadConfigManager.convert( + config=config_dict + ) + + additional_features.opening_statement, additional_features.suggested_questions = \ + OpeningStatementConfigManager.convert( + config=config_dict + ) + + additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert( + config=config_dict + ) + + additional_features.more_like_this = MoreLikeThisConfigManager.convert( + config=config_dict + ) + + additional_features.speech_to_text = SpeechToTextConfigManager.convert( + config=config_dict + ) + + additional_features.text_to_speech = TextToSpeechConfigManager.convert( + config=config_dict + ) + + return additional_features diff --git a/api/core/app/agent_chat/__init__.py b/api/core/app/app_config/common/__init__.py similarity index 100% rename from api/core/app/agent_chat/__init__.py rename to api/core/app/app_config/common/__init__.py diff --git a/api/core/app/chat/__init__.py b/api/core/app/app_config/common/sensitive_word_avoidance/__init__.py similarity index 100% rename from api/core/app/chat/__init__.py rename to api/core/app/app_config/common/sensitive_word_avoidance/__init__.py diff --git a/api/core/app/validators/moderation.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py similarity index 64% rename from api/core/app/validators/moderation.py rename to api/core/app/app_config/common/sensitive_word_avoidance/manager.py index 7a5dff55c9..3dccfa3cbe 100644 --- a/api/core/app/validators/moderation.py +++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py @@ -1,11 +1,24 @@ -import logging +from typing import Optional +from 
core.app.app_config.entities import SensitiveWordAvoidanceEntity
 from core.moderation.factory import ModerationFactory
 
-logger = logging.getLogger(__name__)
 
+class SensitiveWordAvoidanceConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]:
+        sensitive_word_avoidance_dict = config.get('sensitive_word_avoidance')
+        if not sensitive_word_avoidance_dict:
+            return None
+
+        if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']:
+            return SensitiveWordAvoidanceEntity(
+                type=sensitive_word_avoidance_dict.get('type'),
+                config=sensitive_word_avoidance_dict.get('config'),
+            )
+        else:
+            return None
 
-class ModerationValidator:
     @classmethod
     def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \
             -> tuple[dict, list[str]]:
diff --git a/api/core/app/completion/__init__.py b/api/core/app/app_config/easy_ui_based_app/__init__.py
similarity index 100%
rename from api/core/app/completion/__init__.py
rename to api/core/app/app_config/easy_ui_based_app/__init__.py
diff --git a/api/core/app/validators/__init__.py b/api/core/app/app_config/easy_ui_based_app/agent/__init__.py
similarity index 100%
rename from api/core/app/validators/__init__.py
rename to api/core/app/app_config/easy_ui_based_app/agent/__init__.py
diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py
new file mode 100644
index 0000000000..b50b7f678c
--- /dev/null
+++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py
@@ -0,0 +1,79 @@
+from typing import Optional
+
+from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity
+from core.tools.prompt.template import REACT_PROMPT_TEMPLATES
+
+
+class AgentConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> Optional[AgentEntity]:
+        """
+        Convert model config to agent entity
+
+        :param config: model config args
+        """
+        if 'agent_mode' in config and config['agent_mode'] \
+                and 'enabled' in config['agent_mode'] \
+                and config['agent_mode']['enabled']:
+
+            agent_dict = config.get('agent_mode', {})
+            agent_strategy = agent_dict.get('strategy', 'cot')
+
+            if agent_strategy == 'function_call':
+                strategy = AgentEntity.Strategy.FUNCTION_CALLING
+            elif agent_strategy == 'cot' or agent_strategy == 'react':
+                strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
+            else:
+                # old configs, try to detect default strategy
+                if config['model']['provider'] == 'openai':
+                    strategy = AgentEntity.Strategy.FUNCTION_CALLING
+                else:
+                    strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
+
+            agent_tools = []
+            for tool in agent_dict.get('tools', []):
+                keys = tool.keys()
+                if len(keys) >= 4:
+                    if "enabled" not in tool or not tool["enabled"]:
+                        continue
+
+                    agent_tool_properties = {
+                        'provider_type': tool['provider_type'],
+                        'provider_id': tool['provider_id'],
+                        'tool_name': tool['tool_name'],
+                        'tool_parameters': tool['tool_parameters'] if 'tool_parameters' in tool else {}
+                    }
+
+                    agent_tools.append(AgentToolEntity(**agent_tool_properties))
+
+            if 'strategy' in config['agent_mode'] and \
+                    config['agent_mode']['strategy'] not in ['react_router', 'router']:
+                agent_prompt = agent_dict.get('prompt', None) or {}
+                # check model mode
+                model_mode = config.get('model', {}).get('mode', 'completion')
+                if model_mode == 'completion':
+                    agent_prompt_entity = AgentPromptEntity(
+                        first_prompt=agent_prompt.get('first_prompt',
+                                                      REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
+
next_iteration=agent_prompt.get('next_iteration',
+                                                        REACT_PROMPT_TEMPLATES['english']['completion'][
+                                                            'agent_scratchpad']),
+                    )
+                else:
+                    agent_prompt_entity = AgentPromptEntity(
+                        first_prompt=agent_prompt.get('first_prompt',
+                                                      REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
+                        next_iteration=agent_prompt.get('next_iteration',
+                                                        REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
+                    )
+
+            return AgentEntity(
+                provider=config['model']['provider'],
+                model=config['model']['name'],
+                strategy=strategy,
+                prompt=agent_prompt_entity,
+                tools=agent_tools,
+                max_iteration=agent_dict.get('max_iteration', 5)
+            )
+
+        return None
diff --git a/api/core/app/workflow/__init__.py b/api/core/app/app_config/easy_ui_based_app/dataset/__init__.py
similarity index 100%
rename from api/core/app/workflow/__init__.py
rename to api/core/app/app_config/easy_ui_based_app/dataset/__init__.py
diff --git a/api/core/app/validators/dataset_retrieval.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py
similarity index 63%
rename from api/core/app/validators/dataset_retrieval.py
rename to api/core/app/app_config/easy_ui_based_app/dataset/manager.py
index fb5b648320..4c08f62d27 100644
--- a/api/core/app/validators/dataset_retrieval.py
+++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py
@@ -1,11 +1,94 @@
-import uuid
+from typing import Optional
 
+from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
 from core.entities.agent_entities import PlanningStrategy
 from models.model import AppMode
 from services.dataset_service import DatasetService
 
 
-class DatasetValidator:
+class DatasetConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> Optional[DatasetEntity]:
+        """
+        Convert model config to dataset entity
+
+        :param config: model config args
+        """
+        dataset_ids = []
+        if 'datasets' in config.get('dataset_configs', {}):
+            datasets = config.get('dataset_configs', {}).get('datasets', {
+                'strategy': 'router',
+                'datasets': []
+            })
+
+            for dataset in datasets.get('datasets', []):
+                keys = list(dataset.keys())
+                if len(keys) == 0 or keys[0] != 'dataset':
+                    continue
+
+                dataset = dataset['dataset']
+
+                if 'enabled' not in dataset or not dataset['enabled']:
+                    continue
+
+                dataset_id = dataset.get('id', None)
+                if dataset_id:
+                    dataset_ids.append(dataset_id)
+
+        if 'agent_mode' in config and config['agent_mode'] \
+                and 'enabled' in config['agent_mode'] \
+                and config['agent_mode']['enabled']:
+
+            agent_dict = config.get('agent_mode', {})
+
+            for tool in agent_dict.get('tools', []):
+                keys = tool.keys()
+                if len(keys) == 1:
+                    # old standard
+                    key = list(tool.keys())[0]
+
+                    if key != 'dataset':
+                        continue
+
+                    tool_item = tool[key]
+
+                    if "enabled" not in tool_item or not tool_item["enabled"]:
+                        continue
+
+                    dataset_id = tool_item['id']
+                    dataset_ids.append(dataset_id)
+
+        if len(dataset_ids) == 0:
+            return None
+
+        # dataset configs
+        dataset_configs = config.get('dataset_configs', {'retrieval_model': 'single'})
+        query_variable = config.get('dataset_query_variable')
+
+        if dataset_configs['retrieval_model'] == 'single':
+            return DatasetEntity(
+                dataset_ids=dataset_ids,
+                retrieve_config=DatasetRetrieveConfigEntity(
+                    query_variable=query_variable,
+                    retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
+                        dataset_configs['retrieval_model']
+                    )
+                )
+            )
+        else:
+            return DatasetEntity(
+                dataset_ids=dataset_ids,
+                retrieve_config=DatasetRetrieveConfigEntity(
+                    query_variable=query_variable,
+
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
+                        dataset_configs['retrieval_model']
+                    ),
+                    top_k=dataset_configs.get('top_k'),
+                    score_threshold=dataset_configs.get('score_threshold'),
+                    reranking_model=dataset_configs.get('reranking_model')
+                )
+            )
+
     @classmethod
     def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
         """
diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/__init__.py b/api/core/app/app_config/easy_ui_based_app/model_config/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
new file mode 100644
index 0000000000..05fcb10791
--- /dev/null
+++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
@@ -0,0 +1,104 @@
+from typing import cast
+
+from core.app.app_config.entities import EasyUIBasedAppConfig
+from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
+
+from core.entities.model_entities import ModelStatus
+from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.provider_manager import ProviderManager
+
+
+class EasyUIBasedModelConfigEntityConverter:
+    @classmethod
+    def convert(cls, app_config: EasyUIBasedAppConfig,
+                skip_check: bool = False) \
+            -> EasyUIBasedModelConfigEntity:
+        """
+        Convert app model config dict to entity.
+        :param app_config: app config
+        :param skip_check: skip check
+        :raises ProviderTokenNotInitError: provider token not init error
+        :return: model config entity
+        """
+        model_config = app_config.model
+
+        provider_manager = ProviderManager()
+        provider_model_bundle = provider_manager.get_provider_model_bundle(
+            tenant_id=app_config.tenant_id,
+            provider=model_config.provider,
+            model_type=ModelType.LLM
+        )
+
+        provider_name = provider_model_bundle.configuration.provider.provider
+        model_name = model_config.model
+
+        model_type_instance = provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        # check model credentials
+        model_credentials = provider_model_bundle.configuration.get_current_credentials(
+            model_type=ModelType.LLM,
+            model=model_config.model
+        )
+
+        if model_credentials is None:
+            if not skip_check:
+                raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
+            else:
+                model_credentials = {}
+
+        if not skip_check:
+            # check model
+            provider_model = provider_model_bundle.configuration.get_provider_model(
+                model=model_config.model,
+                model_type=ModelType.LLM
+            )
+
+            if provider_model is None:
+                model_name = model_config.model
+                raise ValueError(f"Model {model_name} not exist.")
+
+            if provider_model.status == ModelStatus.NO_CONFIGURE:
+                raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
+            elif provider_model.status == ModelStatus.NO_PERMISSION:
+                raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
+            elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
+                raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
+
+        # model config
+        completion_params = model_config.parameters
+        stop = []
+        if 'stop' in completion_params:
+
stop = completion_params['stop']
+            del completion_params['stop']
+
+        # get model mode
+        model_mode = model_config.mode
+        if not model_mode:
+            mode_enum = model_type_instance.get_model_mode(
+                model=model_config.model,
+                credentials=model_credentials
+            )
+
+            model_mode = mode_enum.value
+
+        model_schema = model_type_instance.get_model_schema(
+            model_config.model,
+            model_credentials
+        )
+
+        if not skip_check and not model_schema:
+            raise ValueError(f"Model {model_name} not exist.")
+
+        return EasyUIBasedModelConfigEntity(
+            provider=model_config.provider,
+            model=model_config.model,
+            model_schema=model_schema,
+            mode=model_mode,
+            provider_model_bundle=provider_model_bundle,
+            credentials=model_credentials,
+            parameters=completion_params,
+            stop=stop,
+        )
diff --git a/api/core/app/validators/model_validator.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py
similarity index 73%
rename from api/core/app/validators/model_validator.py
rename to api/core/app/app_config/easy_ui_based_app/model_config/manager.py
index 1d86fbaf04..5cca2bc1a7 100644
--- a/api/core/app/validators/model_validator.py
+++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py
@@ -1,10 +1,40 @@
-
-from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
+from core.app.app_config.entities import ModelConfigEntity
+from core.model_runtime.entities.model_entities import ModelType, ModelPropertyKey
 from core.model_runtime.model_providers import model_provider_factory
 from core.provider_manager import ProviderManager
 
 
-class ModelValidator:
+class ModelConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> ModelConfigEntity:
+        """
+        Convert model config to model config entity
+
+        :param config: model config args
+        """
+        # model config
+        model_config = config.get('model')
+
+        if not model_config:
+            raise ValueError("model is required")
+
+        completion_params = model_config.get('completion_params')
+        stop = []
+        if 'stop' in completion_params:
+            stop = completion_params['stop']
+            del completion_params['stop']
+
+        # get model mode
+        model_mode = model_config.get('mode')
+
+        return ModelConfigEntity(
+            provider=config['model']['provider'],
+            model=config['model']['name'],
+            mode=model_mode,
+            parameters=completion_params,
+            stop=stop,
+        )
+
     @classmethod
     def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
         """
diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/__init__.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/app/validators/prompt.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
similarity index 58%
rename from api/core/app/validators/prompt.py
rename to api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
index 288a523415..5629d0d09e 100644
--- a/api/core/app/validators/prompt.py
+++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
@@ -1,10 +1,61 @@
-
-from core.entities.application_entities import PromptTemplateEntity
+from core.app.app_config.entities import PromptTemplateEntity, \
+    AdvancedChatPromptTemplateEntity, AdvancedCompletionPromptTemplateEntity
+from core.model_runtime.entities.message_entities import PromptMessageRole
 from core.prompt.simple_prompt_transform import ModelMode
 from models.model import AppMode
 
 
-class PromptValidator:
+class PromptTemplateConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> 
PromptTemplateEntity:
+        if not config.get("prompt_type"):
+            raise ValueError("prompt_type is required")
+
+        prompt_type = PromptTemplateEntity.PromptType.value_of(config['prompt_type'])
+        if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
+            simple_prompt_template = config.get("pre_prompt", "")
+            return PromptTemplateEntity(
+                prompt_type=prompt_type,
+                simple_prompt_template=simple_prompt_template
+            )
+        else:
+            advanced_chat_prompt_template = None
+            chat_prompt_config = config.get("chat_prompt_config", {})
+            if chat_prompt_config:
+                chat_prompt_messages = []
+                for message in chat_prompt_config.get("prompt", []):
+                    chat_prompt_messages.append({
+                        "text": message["text"],
+                        "role": PromptMessageRole.value_of(message["role"])
+                    })
+
+                advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
+                    messages=chat_prompt_messages
+                )
+
+            advanced_completion_prompt_template = None
+            completion_prompt_config = config.get("completion_prompt_config", {})
+            if completion_prompt_config:
+                completion_prompt_template_params = {
+                    'prompt': completion_prompt_config['prompt']['text'],
+                }
+
+                if 'conversation_histories_role' in completion_prompt_config:
+                    completion_prompt_template_params['role_prefix'] = {
+                        'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
+                        'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
+                    }
+
+                advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
+                    **completion_prompt_template_params
+                )
+
+            return PromptTemplateEntity(
+                prompt_type=prompt_type,
+                advanced_chat_prompt_template=advanced_chat_prompt_template,
+                advanced_completion_prompt_template=advanced_completion_prompt_template
+            )
+
     @classmethod
     def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
         """
@@ -83,4 +134,4 @@ class PromptValidator:
         if not isinstance(config["post_prompt"], str):
             raise ValueError("post_prompt must be of string type")
 
-        return config
\ No newline at end of file
+        return config
diff --git a/api/core/app/app_config/easy_ui_based_app/variables/__init__.py b/api/core/app/app_config/easy_ui_based_app/variables/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py
new file mode 100644
index 0000000000..ff962a5439
--- /dev/null
+++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py
@@ -0,0 +1,184 @@
+import re
+from typing import Tuple
+
+from core.app.app_config.entities import VariableEntity, ExternalDataVariableEntity
+from core.external_data_tool.factory import ExternalDataToolFactory
+
+
+class BasicVariablesConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> Tuple[list[VariableEntity], list[ExternalDataVariableEntity]]:
+        """
+        Convert model config to variables and external data variables
+
+        :param config: model config args
+        """
+        external_data_variables = []
+        variables = []
+
+        # old external_data_tools
+        external_data_tools = config.get('external_data_tools', [])
+        for external_data_tool in external_data_tools:
+            if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
+                continue
+
+            external_data_variables.append(
+                ExternalDataVariableEntity(
+                    variable=external_data_tool['variable'],
+                    type=external_data_tool['type'],
+                    config=external_data_tool['config']
+                )
+            )
+
+        # variables and external_data_tools
+        for variable in config.get('user_input_form', []):
+            typ = 
list(variable.keys())[0]
+            if typ == 'external_data_tool':
+                val = variable[typ]
+                external_data_variables.append(
+                    ExternalDataVariableEntity(
+                        variable=val['variable'],
+                        type=val['type'],
+                        config=val['config']
+                    )
+                )
+            elif typ in [
+                VariableEntity.Type.TEXT_INPUT.value,
+                VariableEntity.Type.PARAGRAPH.value,
+                VariableEntity.Type.NUMBER.value,
+            ]:
+                variables.append(
+                    VariableEntity(
+                        type=VariableEntity.Type.value_of(typ),
+                        variable=variable[typ].get('variable'),
+                        description=variable[typ].get('description'),
+                        label=variable[typ].get('label'),
+                        required=variable[typ].get('required', False),
+                        max_length=variable[typ].get('max_length'),
+                        default=variable[typ].get('default'),
+                    )
+                )
+            elif typ == VariableEntity.Type.SELECT.value:
+                variables.append(
+                    VariableEntity(
+                        type=VariableEntity.Type.SELECT,
+                        variable=variable[typ].get('variable'),
+                        description=variable[typ].get('description'),
+                        label=variable[typ].get('label'),
+                        required=variable[typ].get('required', False),
+                        options=variable[typ].get('options'),
+                        default=variable[typ].get('default'),
+                    )
+                )
+
+        return variables, external_data_variables
+
+    @classmethod
+    def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
+        """
+        Validate and set defaults for user input form
+
+        :param tenant_id: workspace id
+        :param config: app model config args
+        """
+        related_config_keys = []
+        config, current_related_config_keys = cls.validate_variables_and_set_defaults(config)
+        related_config_keys.extend(current_related_config_keys)
+
+        config, current_related_config_keys = cls.validate_external_data_tools_and_set_defaults(tenant_id, config)
+        related_config_keys.extend(current_related_config_keys)
+
+        return config, related_config_keys
+
+    @classmethod
+    def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
+        """
+        Validate and set defaults for user input form
+
+        :param config: app model config args
+        """
+        if not config.get("user_input_form"):
+            config["user_input_form"] = []
+
+        if not isinstance(config["user_input_form"], list):
+            raise ValueError("user_input_form must be a list of objects")
+
+        variables = []
+        for item in config["user_input_form"]:
+            key = list(item.keys())[0]
+            if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]:
+                raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph', 'select', 'number' or 'external_data_tool'")
+
+            form_item = item[key]
+            if 'label' not in form_item:
+                raise ValueError("label is required in user_input_form")
+
+            if not isinstance(form_item["label"], str):
+                raise ValueError("label in user_input_form must be of string type")
+
+            if 'variable' not in form_item:
+                raise ValueError("variable is required in user_input_form")
+
+            if not isinstance(form_item["variable"], str):
+                raise ValueError("variable in user_input_form must be of string type")
+
+            pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
+            if pattern.match(form_item["variable"]) is None:
+                raise ValueError("variable in user_input_form must be a string, "
+                                 "and cannot start with a number")
+
+            variables.append(form_item["variable"])
+
+            if 'required' not in form_item or not form_item["required"]:
+                form_item["required"] = False
+
+            if not isinstance(form_item["required"], bool):
+                raise ValueError("required in user_input_form must be of boolean type")
+
+            if key == "select":
+                if 'options' not in form_item or not form_item["options"]:
+                    form_item["options"] = []
+
+                if not 
isinstance(form_item["options"], list): + raise ValueError("options in user_input_form must be a list of strings") + + if "default" in form_item and form_item['default'] \ + and form_item["default"] not in form_item["options"]: + raise ValueError("default value in user_input_form must be in the options list") + + return config, ["user_input_form"] + + @classmethod + def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: + """ + Validate and set defaults for external data fetch feature + + :param tenant_id: workspace id + :param config: app model config args + """ + if not config.get("external_data_tools"): + config["external_data_tools"] = [] + + if not isinstance(config["external_data_tools"], list): + raise ValueError("external_data_tools must be of list type") + + for tool in config["external_data_tools"]: + if "enabled" not in tool or not tool["enabled"]: + tool["enabled"] = False + + if not tool["enabled"]: + continue + + if "type" not in tool or not tool["type"]: + raise ValueError("external_data_tools[].type is required") + + typ = tool["type"] + config = tool["config"] + + ExternalDataToolFactory.validate_config( + name=typ, + tenant_id=tenant_id, + config=config + ) + + return config, ["external_data_tools"] \ No newline at end of file diff --git a/api/core/entities/application_entities.py b/api/core/app/app_config/entities.py similarity index 61% rename from api/core/entities/application_entities.py rename to api/core/app/app_config/entities.py index f5ea4d1eb0..e155dc1c4d 100644 --- a/api/core/entities/application_entities.py +++ b/api/core/app/app_config/entities.py @@ -1,12 +1,10 @@ from enum import Enum -from typing import Any, Literal, Optional, Union +from typing import Any, Optional from pydantic import BaseModel -from core.entities.provider_configuration import ProviderModelBundle -from core.file.file_obj import FileObj from core.model_runtime.entities.message_entities import PromptMessageRole -from core.model_runtime.entities.model_entities import AIModelEntity +from models.model import AppMode class ModelConfigEntity(BaseModel): @@ -15,10 +13,7 @@ class ModelConfigEntity(BaseModel): """ provider: str model: str - model_schema: Optional[AIModelEntity] = None - mode: str - provider_model_bundle: ProviderModelBundle - credentials: dict[str, Any] = {} + mode: Optional[str] = None parameters: dict[str, Any] = {} stop: list[str] = [] @@ -194,149 +189,53 @@ class FileUploadEntity(BaseModel): image_config: Optional[dict[str, Any]] = None -class AgentToolEntity(BaseModel): - """ - Agent Tool Entity. - """ - provider_type: Literal["builtin", "api"] - provider_id: str - tool_name: str - tool_parameters: dict[str, Any] = {} - - -class AgentPromptEntity(BaseModel): - """ - Agent Prompt Entity. - """ - first_prompt: str - next_iteration: str - - -class AgentScratchpadUnit(BaseModel): - """ - Agent First Prompt Entity. - """ - - class Action(BaseModel): - """ - Action Entity. - """ - action_name: str - action_input: Union[dict, str] - - agent_response: Optional[str] = None - thought: Optional[str] = None - action_str: Optional[str] = None - observation: Optional[str] = None - action: Optional[Action] = None - - -class AgentEntity(BaseModel): - """ - Agent Entity. - """ - - class Strategy(Enum): - """ - Agent Strategy. 
- """ - CHAIN_OF_THOUGHT = 'chain-of-thought' - FUNCTION_CALLING = 'function-calling' - - provider: str - model: str - strategy: Strategy - prompt: Optional[AgentPromptEntity] = None - tools: list[AgentToolEntity] = None - max_iteration: int = 5 - - -class AppOrchestrationConfigEntity(BaseModel): - """ - App Orchestration Config Entity. - """ - model_config: ModelConfigEntity - prompt_template: PromptTemplateEntity - variables: list[VariableEntity] = [] - external_data_variables: list[ExternalDataVariableEntity] = [] - agent: Optional[AgentEntity] = None - - # features - dataset: Optional[DatasetEntity] = None +class AppAdditionalFeatures(BaseModel): file_upload: Optional[FileUploadEntity] = None opening_statement: Optional[str] = None + suggested_questions: list[str] = [] suggested_questions_after_answer: bool = False show_retrieve_source: bool = False more_like_this: bool = False speech_to_text: bool = False text_to_speech: Optional[TextToSpeechEntity] = None + + +class AppConfig(BaseModel): + """ + Application Config Entity. + """ + tenant_id: str + app_id: str + app_mode: AppMode + additional_features: AppAdditionalFeatures + variables: list[VariableEntity] = [] sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None -class InvokeFrom(Enum): +class EasyUIBasedAppModelConfigFrom(Enum): """ - Invoke From. + App Model Config From. """ - SERVICE_API = 'service-api' - WEB_APP = 'web-app' - EXPLORE = 'explore' - DEBUGGER = 'debugger' - - @classmethod - def value_of(cls, value: str) -> 'InvokeFrom': - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f'invalid invoke from value {value}') - - def to_source(self) -> str: - """ - Get source of invoke from. - - :return: source - """ - if self == InvokeFrom.WEB_APP: - return 'web_app' - elif self == InvokeFrom.DEBUGGER: - return 'dev' - elif self == InvokeFrom.EXPLORE: - return 'explore_app' - elif self == InvokeFrom.SERVICE_API: - return 'api' - - return 'dev' + ARGS = 'args' + APP_LATEST_CONFIG = 'app-latest-config' + CONVERSATION_SPECIFIC_CONFIG = 'conversation-specific-config' -class ApplicationGenerateEntity(BaseModel): +class EasyUIBasedAppConfig(AppConfig): """ - Application Generate Entity. + Easy UI Based App Config Entity. """ - task_id: str - tenant_id: str - - app_id: str + app_model_config_from: EasyUIBasedAppModelConfigFrom app_model_config_id: str - # for save app_model_config_dict: dict - app_model_config_override: bool + model: ModelConfigEntity + prompt_template: PromptTemplateEntity + dataset: Optional[DatasetEntity] = None + external_data_variables: list[ExternalDataVariableEntity] = [] - # Converted from app_model_config to Entity object, or directly covered by external input - app_orchestration_config_entity: AppOrchestrationConfigEntity - conversation_id: Optional[str] = None - inputs: dict[str, str] - query: Optional[str] = None - files: list[FileObj] = [] - user_id: str - # extras - stream: bool - invoke_from: InvokeFrom - - # extra parameters, like: auto_generate_conversation_name - extras: dict[str, Any] = {} +class WorkflowUIBasedAppConfig(AppConfig): + """ + Workflow UI Based App Config Entity. 
+    """
+    workflow_id: str
diff --git a/api/core/app/app_config/features/__init__.py b/api/core/app/app_config/features/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/app/app_config/features/file_upload/__init__.py b/api/core/app/app_config/features/file_upload/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/app/validators/file_upload.py b/api/core/app/app_config/features/file_upload/manager.py
similarity index 59%
rename from api/core/app/validators/file_upload.py
rename to api/core/app/app_config/features/file_upload/manager.py
index 419465bd51..63830696ff 100644
--- a/api/core/app/validators/file_upload.py
+++ b/api/core/app/app_config/features/file_upload/manager.py
@@ -1,6 +1,30 @@
+from typing import Optional
+
+from core.app.app_config.entities import FileUploadEntity
 
-class FileUploadValidator:
+
+class FileUploadConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> Optional[FileUploadEntity]:
+        """
+        Convert model config to file upload entity
+
+        :param config: model config args
+        """
+        file_upload_dict = config.get('file_upload')
+        if file_upload_dict:
+            if 'image' in file_upload_dict and file_upload_dict['image']:
+                if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
+                    return FileUploadEntity(
+                        image_config={
+                            'number_limits': file_upload_dict['image']['number_limits'],
+                            'detail': file_upload_dict['image']['detail'],
+                            'transfer_methods': file_upload_dict['image']['transfer_methods']
+                        }
+                    )
+
+        return None
+
     @classmethod
     def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
         """
diff --git a/api/core/app/app_config/features/more_like_this/__init__.py b/api/core/app/app_config/features/more_like_this/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/app/validators/more_like_this.py b/api/core/app/app_config/features/more_like_this/manager.py
similarity index 63%
rename from api/core/app/validators/more_like_this.py
rename to api/core/app/app_config/features/more_like_this/manager.py
index 1c1bac9de6..ec2a9a6796 100644
--- a/api/core/app/validators/more_like_this.py
+++ b/api/core/app/app_config/features/more_like_this/manager.py
@@ -1,6 +1,19 @@
+class MoreLikeThisConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> bool:
+        """
+        Convert model config to more_like_this flag
+        :param config: model config args
+        """
+        more_like_this = False
+        more_like_this_dict = config.get('more_like_this')
+        if more_like_this_dict:
+            if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']:
+                more_like_this = True
+
+        return more_like_this
 
-class MoreLikeThisValidator:
     @classmethod
     def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
         """
diff --git a/api/core/app/app_config/features/opening_statement/__init__.py b/api/core/app/app_config/features/opening_statement/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/app/validators/opening_statement.py b/api/core/app/app_config/features/opening_statement/manager.py
similarity index 66%
rename from api/core/app/validators/opening_statement.py
rename to api/core/app/app_config/features/opening_statement/manager.py
index f919230e0d..6183c6e749 100644
--- a/api/core/app/validators/opening_statement.py
+++ b/api/core/app/app_config/features/opening_statement/manager.py
@@ -1,6 +1,22 @@
+from typing import Tuple
 
-class OpeningStatementValidator:
+
+class OpeningStatementConfigManager:
+    @classmethod
+    def convert(cls, 
config: dict) -> Tuple[str, list]:
+        """
+        Convert model config to opening statement and suggested questions
+
+        :param config: model config args
+        """
+        # opening statement
+        opening_statement = config.get('opening_statement')
+
+        # suggested questions
+        suggested_questions_list = config.get('suggested_questions')
+
+        return opening_statement, suggested_questions_list
+
     @classmethod
     def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
         """
diff --git a/api/core/app/app_config/features/retrieval_resource/__init__.py b/api/core/app/app_config/features/retrieval_resource/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/app/validators/retriever_resource.py b/api/core/app/app_config/features/retrieval_resource/manager.py
similarity index 68%
rename from api/core/app/validators/retriever_resource.py
rename to api/core/app/app_config/features/retrieval_resource/manager.py
index 32725c7432..0694cb954e 100644
--- a/api/core/app/validators/retriever_resource.py
+++ b/api/core/app/app_config/features/retrieval_resource/manager.py
@@ -1,6 +1,14 @@
+class RetrievalResourceConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> bool:
+        show_retrieve_source = False
+        retriever_resource_dict = config.get('retriever_resource')
+        if retriever_resource_dict:
+            if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']:
+                show_retrieve_source = True
+        return show_retrieve_source
 
-class RetrieverResourceValidator:
     @classmethod
     def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
         """
diff --git a/api/core/app/app_config/features/speech_to_text/__init__.py b/api/core/app/app_config/features/speech_to_text/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/app/validators/speech_to_text.py b/api/core/app/app_config/features/speech_to_text/manager.py
similarity index 63%
rename from api/core/app/validators/speech_to_text.py
rename to api/core/app/app_config/features/speech_to_text/manager.py
index 92a1b25ae6..b98699bfff 100644
--- a/api/core/app/validators/speech_to_text.py
+++ b/api/core/app/app_config/features/speech_to_text/manager.py
@@ -1,6 +1,19 @@
+class SpeechToTextConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> bool:
+        """
+        Convert model config to speech_to_text flag
+        :param config: model config args
+        """
+        speech_to_text = False
+        speech_to_text_dict = config.get('speech_to_text')
+        if speech_to_text_dict:
+            if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
+                speech_to_text = True
+
+        return speech_to_text
 
-class SpeechToTextValidator:
     @classmethod
     def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
         """
diff --git a/api/core/app/app_config/features/suggested_questions_after_answer/__init__.py b/api/core/app/app_config/features/suggested_questions_after_answer/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/app/validators/suggested_questions.py b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py
similarity index 57%
rename from api/core/app/validators/suggested_questions.py
rename to api/core/app/app_config/features/suggested_questions_after_answer/manager.py
index 9161b31678..5aacd3b32d 100644
--- a/api/core/app/validators/suggested_questions.py
+++ b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py
@@ -1,6 +1,19 @@
+class SuggestedQuestionsAfterAnswerConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> bool:
+        """
+        Convert model 
config to suggested_questions_after_answer flag
+        :param config: model config args
+        """
+        suggested_questions_after_answer = False
+        suggested_questions_after_answer_dict = config.get('suggested_questions_after_answer')
+        if suggested_questions_after_answer_dict:
+            if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']:
+                suggested_questions_after_answer = True
+
+        return suggested_questions_after_answer
 
-class SuggestedQuestionsValidator:
     @classmethod
     def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
         """
@@ -16,7 +29,8 @@ class SuggestedQuestionsValidator:
         if not isinstance(config["suggested_questions_after_answer"], dict):
             raise ValueError("suggested_questions_after_answer must be of dict type")
 
-        if "enabled" not in config["suggested_questions_after_answer"] or not config["suggested_questions_after_answer"]["enabled"]:
+        if "enabled" not in config["suggested_questions_after_answer"] or not \
+                config["suggested_questions_after_answer"]["enabled"]:
             config["suggested_questions_after_answer"]["enabled"] = False
 
         if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):
diff --git a/api/core/app/app_config/features/text_to_speech/__init__.py b/api/core/app/app_config/features/text_to_speech/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/app/validators/text_to_speech.py b/api/core/app/app_config/features/text_to_speech/manager.py
similarity index 56%
rename from api/core/app/validators/text_to_speech.py
rename to api/core/app/app_config/features/text_to_speech/manager.py
index 182a912d52..1ff31034ad 100644
--- a/api/core/app/validators/text_to_speech.py
+++ b/api/core/app/app_config/features/text_to_speech/manager.py
@@ -1,6 +1,27 @@
+from typing import Optional
+from core.app.app_config.entities import TextToSpeechEntity
 
-class TextToSpeechValidator:
+
+class TextToSpeechConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> Optional[TextToSpeechEntity]:
+        """
+        Convert model config to text_to_speech entity
+
+        :param config: model config args
+        """
+        text_to_speech = None
+        text_to_speech_dict = config.get('text_to_speech')
+        if text_to_speech_dict:
+            if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']:
+                text_to_speech = TextToSpeechEntity(
+                    enabled=text_to_speech_dict.get('enabled'),
+                    voice=text_to_speech_dict.get('voice'),
+                    language=text_to_speech_dict.get('language'),
+                )
+
+        return text_to_speech
+
     @classmethod
     def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
         """
diff --git a/api/core/app/app_config/workflow_ui_based_app/__init__.py b/api/core/app/app_config/workflow_ui_based_app/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/__init__.py b/api/core/app/app_config/workflow_ui_based_app/variables/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py
new file mode 100644
index 0000000000..4b117d87f8
--- /dev/null
+++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py
@@ -0,0 +1,22 @@
+from core.app.app_config.entities import VariableEntity
+from models.workflow import Workflow
+
+
+class WorkflowVariablesConfigManager:
+    @classmethod
+    def convert(cls, workflow: Workflow) -> list[VariableEntity]:
+        """
+        Convert workflow start variables to variables
+
+        :param workflow: workflow instance
+        """
+        variables = []
+
+        # find start 
node + user_input_form = workflow.user_input_form() + + # variables + for variable in user_input_form: + variables.append(VariableEntity(**variable)) + + return variables diff --git a/api/core/app/app_manager.py b/api/core/app/app_manager.py index 86c8d2cfc7..98ebe2c87d 100644 --- a/api/core/app/app_manager.py +++ b/api/core/app/app_manager.py @@ -8,13 +8,18 @@ from typing import Any, Optional, Union, cast from flask import Flask, current_app from pydantic import ValidationError -from core.app.agent_chat.app_runner import AgentChatAppRunner -from core.app.app_orchestration_config_converter import AppOrchestrationConfigConverter +from core.app.app_config.easy_ui_based_app.model_config.converter import EasyUIBasedModelConfigEntityConverter +from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom, EasyUIBasedAppConfig, VariableEntity +from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager +from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom -from core.app.chat.app_runner import ChatAppRunner +from core.app.apps.chat.app_config_manager import ChatAppConfigManager +from core.app.apps.chat.app_runner import ChatAppRunner +from core.app.apps.completion.app_config_manager import CompletionAppConfigManager +from core.app.apps.completion.app_runner import CompletionAppRunner from core.app.generate_task_pipeline import GenerateTaskPipeline -from core.entities.application_entities import ( - ApplicationGenerateEntity, +from core.app.entities.app_invoke_entities import ( + EasyUIBasedAppGenerateEntity, InvokeFrom, ) from core.file.file_obj import FileObj @@ -23,24 +28,19 @@ from core.model_runtime.model_providers.__base.large_language_model import Large from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db from models.account import Account -from models.model import App, Conversation, EndUser, Message, MessageFile +from models.model import App, Conversation, EndUser, Message, MessageFile, AppMode, AppModelConfig logger = logging.getLogger(__name__) -class AppManager: - """ - This class is responsible for managing application - """ +class EasyUIBasedAppManager: - def generate(self, tenant_id: str, - app_id: str, - app_model_config_id: str, - app_model_config_dict: dict, - app_model_config_override: bool, + def generate(self, app_model: App, + app_model_config: AppModelConfig, user: Union[Account, EndUser], invoke_from: InvokeFrom, inputs: dict[str, str], + app_model_config_dict: Optional[dict] = None, query: Optional[str] = None, files: Optional[list[FileObj]] = None, conversation: Optional[Conversation] = None, @@ -50,14 +50,12 @@ class AppManager: """ Generate App response. 
- :param tenant_id: workspace ID - :param app_id: app ID - :param app_model_config_id: app model config id - :param app_model_config_dict: app model config dict - :param app_model_config_override: app model config override + :param app_model: App + :param app_model_config: app model config :param user: account or end user :param invoke_from: invoke from source :param inputs: inputs + :param app_model_config_dict: app model config dict :param query: query :param files: file obj list :param conversation: conversation @@ -67,20 +65,21 @@ class AppManager: # init task id task_id = str(uuid.uuid4()) - # init application generate entity - application_generate_entity = ApplicationGenerateEntity( - task_id=task_id, - tenant_id=tenant_id, - app_id=app_id, - app_model_config_id=app_model_config_id, + # convert to app config + app_config = self.convert_to_app_config( + app_model=app_model, + app_model_config=app_model_config, app_model_config_dict=app_model_config_dict, - app_orchestration_config_entity=AppOrchestrationConfigConverter.convert_from_app_model_config_dict( - tenant_id=tenant_id, - app_model_config_dict=app_model_config_dict - ), - app_model_config_override=app_model_config_override, + conversation=conversation + ) + + # init application generate entity + application_generate_entity = EasyUIBasedAppGenerateEntity( + task_id=task_id, + app_config=app_config, + model_config=EasyUIBasedModelConfigEntityConverter.convert(app_config), conversation_id=conversation.id if conversation else None, - inputs=conversation.inputs if conversation else inputs, + inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), query=query.replace('\x00', '') if query else None, files=files if files else [], user_id=user.id, @@ -89,7 +88,7 @@ class AppManager: extras=extras ) - if not stream and application_generate_entity.app_orchestration_config_entity.agent: + if not stream and application_generate_entity.app_config.app_mode == AppMode.AGENT_CHAT: raise ValueError("Agent app is not supported in blocking mode.") # init generate records @@ -128,8 +127,85 @@ class AppManager: stream=stream ) + def convert_to_app_config(self, app_model: App, + app_model_config: AppModelConfig, + app_model_config_dict: Optional[dict] = None, + conversation: Optional[Conversation] = None) -> EasyUIBasedAppConfig: + if app_model_config_dict: + config_from = EasyUIBasedAppModelConfigFrom.ARGS + elif conversation: + config_from = EasyUIBasedAppModelConfigFrom.CONVERSATION_SPECIFIC_CONFIG + else: + config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG + + app_mode = AppMode.value_of(app_model.mode) + if app_mode == AppMode.AGENT_CHAT or app_model.is_agent: + app_model.mode = AppMode.AGENT_CHAT.value + app_config = AgentChatAppConfigManager.config_convert( + app_model=app_model, + config_from=config_from, + app_model_config=app_model_config, + config_dict=app_model_config_dict + ) + elif app_mode == AppMode.CHAT: + app_config = ChatAppConfigManager.config_convert( + app_model=app_model, + config_from=config_from, + app_model_config=app_model_config, + config_dict=app_model_config_dict + ) + elif app_mode == AppMode.COMPLETION: + app_config = CompletionAppConfigManager.config_convert( + app_model=app_model, + config_from=config_from, + app_model_config=app_model_config, + config_dict=app_model_config_dict + ) + else: + raise ValueError("Invalid app mode") + + return app_config + + def _get_cleaned_inputs(self, user_inputs: dict, app_config: EasyUIBasedAppConfig): + if user_inputs is None: + 
user_inputs = {} + + filtered_inputs = {} + + # Filter input variables from form configuration, handle required fields, default values, and option values + variables = app_config.variables + for variable_config in variables: + variable = variable_config.variable + + if variable not in user_inputs or not user_inputs[variable]: + if variable_config.required: + raise ValueError(f"{variable} is required in input form") + else: + filtered_inputs[variable] = variable_config.default if variable_config.default is not None else "" + continue + + value = user_inputs[variable] + + if value: + if not isinstance(value, str): + raise ValueError(f"{variable} in input form must be a string") + + if variable_config.type == VariableEntity.Type.SELECT: + options = variable_config.options if variable_config.options is not None else [] + if value not in options: + raise ValueError(f"{variable} in input form must be one of the following: {options}") + else: + if variable_config.max_length is not None: + max_length = variable_config.max_length + if len(value) > max_length: + raise ValueError(f'{variable} in input form must be less than {max_length} characters') + + filtered_inputs[variable] = value.replace('\x00', '') if value else None + + return filtered_inputs + def _generate_worker(self, flask_app: Flask, - application_generate_entity: ApplicationGenerateEntity, + application_generate_entity: EasyUIBasedAppGenerateEntity, queue_manager: AppQueueManager, conversation_id: str, message_id: str) -> None: @@ -148,7 +224,7 @@ class AppManager: conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) - if application_generate_entity.app_orchestration_config_entity.agent: + if application_generate_entity.app_config.app_mode == AppMode.AGENT_CHAT: # agent app runner = AgentChatAppRunner() runner.run( @@ -157,8 +233,8 @@ class AppManager: conversation=conversation, message=message ) - else: - # basic app + elif application_generate_entity.app_config.app_mode == AppMode.CHAT: + # chatbot app runner = ChatAppRunner() runner.run( application_generate_entity=application_generate_entity, @@ -166,6 +242,16 @@ class AppManager: conversation=conversation, message=message ) + elif application_generate_entity.app_config.app_mode == AppMode.COMPLETION: + # completion app + runner = CompletionAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + message=message + ) + else: + raise ValueError("Invalid app mode") except ConversationTaskStoppedException: pass except InvokeAuthorizationError: @@ -184,7 +270,7 @@ class AppManager: finally: db.session.remove() - def _handle_response(self, application_generate_entity: ApplicationGenerateEntity, + def _handle_response(self, application_generate_entity: EasyUIBasedAppGenerateEntity, queue_manager: AppQueueManager, conversation: Conversation, message: Message, @@ -217,24 +303,24 @@ class AppManager: finally: db.session.remove() - def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \ + def _init_generate_records(self, application_generate_entity: EasyUIBasedAppGenerateEntity) \ -> tuple[Conversation, Message]: """ Initialize generate records :param application_generate_entity: application generate entity :return: """ - app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity - - model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance + model_type_instance = 
application_generate_entity.model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) model_schema = model_type_instance.get_model_schema( - model=app_orchestration_config_entity.model_config.model, - credentials=app_orchestration_config_entity.model_config.credentials + model=application_generate_entity.model_config.model, + credentials=application_generate_entity.model_config.credentials ) + app_config = application_generate_entity.app_config + app_record = (db.session.query(App) - .filter(App.id == application_generate_entity.app_id).first()) + .filter(App.id == app_config.app_id).first()) app_mode = app_record.mode @@ -249,8 +335,8 @@ class AppManager: account_id = application_generate_entity.user_id override_model_configs = None - if application_generate_entity.app_model_config_override: - override_model_configs = application_generate_entity.app_model_config_dict + if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS: + override_model_configs = app_config.app_model_config_dict introduction = '' if app_mode == 'chat': @@ -260,9 +346,9 @@ class AppManager: if not application_generate_entity.conversation_id: conversation = Conversation( app_id=app_record.id, - app_model_config_id=application_generate_entity.app_model_config_id, - model_provider=app_orchestration_config_entity.model_config.provider, - model_id=app_orchestration_config_entity.model_config.model, + app_model_config_id=app_config.app_model_config_id, + model_provider=application_generate_entity.model_config.provider, + model_id=application_generate_entity.model_config.model, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, mode=app_mode, name='New conversation', @@ -291,8 +377,8 @@ class AppManager: message = Message( app_id=app_record.id, - model_provider=app_orchestration_config_entity.model_config.provider, - model_id=app_orchestration_config_entity.model_config.model, + model_provider=application_generate_entity.model_config.provider, + model_id=application_generate_entity.model_config.model, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, conversation_id=conversation.id, inputs=application_generate_entity.inputs, @@ -311,7 +397,7 @@ class AppManager: from_source=from_source, from_end_user_id=end_user_id, from_account_id=account_id, - agent_based=app_orchestration_config_entity.agent is not None + agent_based=app_config.app_mode == AppMode.AGENT_CHAT, ) db.session.add(message) @@ -333,14 +419,14 @@ class AppManager: return conversation, message - def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str: + def _get_conversation_introduction(self, application_generate_entity: EasyUIBasedAppGenerateEntity) -> str: """ Get conversation introduction :param application_generate_entity: application generate entity :return: conversation introduction """ - app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity - introduction = app_orchestration_config_entity.opening_statement + app_config = application_generate_entity.app_config + introduction = app_config.additional_features.opening_statement if introduction: try: diff --git a/api/core/app/app_orchestration_config_converter.py b/api/core/app/app_orchestration_config_converter.py deleted file mode 100644 index 1d429ee6d9..0000000000 --- a/api/core/app/app_orchestration_config_converter.py +++ /dev/null @@ -1,421 +0,0 @@ 
-from typing import cast - -from core.entities.application_entities import ( - AdvancedChatPromptTemplateEntity, - AdvancedCompletionPromptTemplateEntity, - AgentEntity, - AgentPromptEntity, - AgentToolEntity, - AppOrchestrationConfigEntity, - DatasetEntity, - DatasetRetrieveConfigEntity, - ExternalDataVariableEntity, - FileUploadEntity, - ModelConfigEntity, - PromptTemplateEntity, - SensitiveWordAvoidanceEntity, - TextToSpeechEntity, - VariableEntity, -) -from core.entities.model_entities import ModelStatus -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.provider_manager import ProviderManager -from core.tools.prompt.template import REACT_PROMPT_TEMPLATES - - -class AppOrchestrationConfigConverter: - @classmethod - def convert_from_app_model_config_dict(cls, tenant_id: str, - app_model_config_dict: dict, - skip_check: bool = False) \ - -> AppOrchestrationConfigEntity: - """ - Convert app model config dict to entity. - :param tenant_id: tenant ID - :param app_model_config_dict: app model config dict - :param skip_check: skip check - :raises ProviderTokenNotInitError: provider token not init error - :return: app orchestration config entity - """ - properties = {} - - copy_app_model_config_dict = app_model_config_dict.copy() - - provider_manager = ProviderManager() - provider_model_bundle = provider_manager.get_provider_model_bundle( - tenant_id=tenant_id, - provider=copy_app_model_config_dict['model']['provider'], - model_type=ModelType.LLM - ) - - provider_name = provider_model_bundle.configuration.provider.provider - model_name = copy_app_model_config_dict['model']['name'] - - model_type_instance = provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - - # check model credentials - model_credentials = provider_model_bundle.configuration.get_current_credentials( - model_type=ModelType.LLM, - model=copy_app_model_config_dict['model']['name'] - ) - - if model_credentials is None: - if not skip_check: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") - else: - model_credentials = {} - - if not skip_check: - # check model - provider_model = provider_model_bundle.configuration.get_provider_model( - model=copy_app_model_config_dict['model']['name'], - model_type=ModelType.LLM - ) - - if provider_model is None: - model_name = copy_app_model_config_dict['model']['name'] - raise ValueError(f"Model {model_name} not exist.") - - if provider_model.status == ModelStatus.NO_CONFIGURE: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") - elif provider_model.status == ModelStatus.NO_PERMISSION: - raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") - elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: - raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") - - # model config - completion_params = copy_app_model_config_dict['model'].get('completion_params') - stop = [] - if 'stop' in completion_params: - stop = completion_params['stop'] - del completion_params['stop'] - - # get model mode - model_mode = copy_app_model_config_dict['model'].get('mode') - if not model_mode: - mode_enum = 
model_type_instance.get_model_mode( - model=copy_app_model_config_dict['model']['name'], - credentials=model_credentials - ) - - model_mode = mode_enum.value - - model_schema = model_type_instance.get_model_schema( - copy_app_model_config_dict['model']['name'], - model_credentials - ) - - if not skip_check and not model_schema: - raise ValueError(f"Model {model_name} not exist.") - - properties['model_config'] = ModelConfigEntity( - provider=copy_app_model_config_dict['model']['provider'], - model=copy_app_model_config_dict['model']['name'], - model_schema=model_schema, - mode=model_mode, - provider_model_bundle=provider_model_bundle, - credentials=model_credentials, - parameters=completion_params, - stop=stop, - ) - - # prompt template - prompt_type = PromptTemplateEntity.PromptType.value_of(copy_app_model_config_dict['prompt_type']) - if prompt_type == PromptTemplateEntity.PromptType.SIMPLE: - simple_prompt_template = copy_app_model_config_dict.get("pre_prompt", "") - properties['prompt_template'] = PromptTemplateEntity( - prompt_type=prompt_type, - simple_prompt_template=simple_prompt_template - ) - else: - advanced_chat_prompt_template = None - chat_prompt_config = copy_app_model_config_dict.get("chat_prompt_config", {}) - if chat_prompt_config: - chat_prompt_messages = [] - for message in chat_prompt_config.get("prompt", []): - chat_prompt_messages.append({ - "text": message["text"], - "role": PromptMessageRole.value_of(message["role"]) - }) - - advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity( - messages=chat_prompt_messages - ) - - advanced_completion_prompt_template = None - completion_prompt_config = copy_app_model_config_dict.get("completion_prompt_config", {}) - if completion_prompt_config: - completion_prompt_template_params = { - 'prompt': completion_prompt_config['prompt']['text'], - } - - if 'conversation_histories_role' in completion_prompt_config: - completion_prompt_template_params['role_prefix'] = { - 'user': completion_prompt_config['conversation_histories_role']['user_prefix'], - 'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix'] - } - - advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity( - **completion_prompt_template_params - ) - - properties['prompt_template'] = PromptTemplateEntity( - prompt_type=prompt_type, - advanced_chat_prompt_template=advanced_chat_prompt_template, - advanced_completion_prompt_template=advanced_completion_prompt_template - ) - - # external data variables - properties['external_data_variables'] = [] - - # old external_data_tools - external_data_tools = copy_app_model_config_dict.get('external_data_tools', []) - for external_data_tool in external_data_tools: - if 'enabled' not in external_data_tool or not external_data_tool['enabled']: - continue - - properties['external_data_variables'].append( - ExternalDataVariableEntity( - variable=external_data_tool['variable'], - type=external_data_tool['type'], - config=external_data_tool['config'] - ) - ) - - properties['variables'] = [] - - # variables and external_data_tools - for variable in copy_app_model_config_dict.get('user_input_form', []): - typ = list(variable.keys())[0] - if typ == 'external_data_tool': - val = variable[typ] - properties['external_data_variables'].append( - ExternalDataVariableEntity( - variable=val['variable'], - type=val['type'], - config=val['config'] - ) - ) - elif typ in [ - VariableEntity.Type.TEXT_INPUT.value, - VariableEntity.Type.PARAGRAPH.value, - VariableEntity.Type.NUMBER.value, - ]: - 
properties['variables'].append( - VariableEntity( - type=VariableEntity.Type.value_of(typ), - variable=variable[typ].get('variable'), - description=variable[typ].get('description'), - label=variable[typ].get('label'), - required=variable[typ].get('required', False), - max_length=variable[typ].get('max_length'), - default=variable[typ].get('default'), - ) - ) - elif typ == VariableEntity.Type.SELECT.value: - properties['variables'].append( - VariableEntity( - type=VariableEntity.Type.SELECT, - variable=variable[typ].get('variable'), - description=variable[typ].get('description'), - label=variable[typ].get('label'), - required=variable[typ].get('required', False), - options=variable[typ].get('options'), - default=variable[typ].get('default'), - ) - ) - - # show retrieve source - show_retrieve_source = False - retriever_resource_dict = copy_app_model_config_dict.get('retriever_resource') - if retriever_resource_dict: - if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']: - show_retrieve_source = True - - properties['show_retrieve_source'] = show_retrieve_source - - dataset_ids = [] - if 'datasets' in copy_app_model_config_dict.get('dataset_configs', {}): - datasets = copy_app_model_config_dict.get('dataset_configs', {}).get('datasets', { - 'strategy': 'router', - 'datasets': [] - }) - - for dataset in datasets.get('datasets', []): - keys = list(dataset.keys()) - if len(keys) == 0 or keys[0] != 'dataset': - continue - dataset = dataset['dataset'] - - if 'enabled' not in dataset or not dataset['enabled']: - continue - - dataset_id = dataset.get('id', None) - if dataset_id: - dataset_ids.append(dataset_id) - - if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \ - and 'enabled' in copy_app_model_config_dict['agent_mode'] \ - and copy_app_model_config_dict['agent_mode']['enabled']: - - agent_dict = copy_app_model_config_dict.get('agent_mode', {}) - agent_strategy = agent_dict.get('strategy', 'cot') - - if agent_strategy == 'function_call': - strategy = AgentEntity.Strategy.FUNCTION_CALLING - elif agent_strategy == 'cot' or agent_strategy == 'react': - strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT - else: - # old configs, try to detect default strategy - if copy_app_model_config_dict['model']['provider'] == 'openai': - strategy = AgentEntity.Strategy.FUNCTION_CALLING - else: - strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT - - agent_tools = [] - for tool in agent_dict.get('tools', []): - keys = tool.keys() - if len(keys) >= 4: - if "enabled" not in tool or not tool["enabled"]: - continue - - agent_tool_properties = { - 'provider_type': tool['provider_type'], - 'provider_id': tool['provider_id'], - 'tool_name': tool['tool_name'], - 'tool_parameters': tool['tool_parameters'] if 'tool_parameters' in tool else {} - } - - agent_tools.append(AgentToolEntity(**agent_tool_properties)) - elif len(keys) == 1: - # old standard - key = list(tool.keys())[0] - - if key != 'dataset': - continue - - tool_item = tool[key] - - if "enabled" not in tool_item or not tool_item["enabled"]: - continue - - dataset_id = tool_item['id'] - dataset_ids.append(dataset_id) - - if 'strategy' in copy_app_model_config_dict['agent_mode'] and \ - copy_app_model_config_dict['agent_mode']['strategy'] not in ['react_router', 'router']: - agent_prompt = agent_dict.get('prompt', None) or {} - # check model mode - model_mode = copy_app_model_config_dict.get('model', {}).get('mode', 'completion') - if model_mode == 'completion': - agent_prompt_entity = 
AgentPromptEntity( - first_prompt=agent_prompt.get('first_prompt', - REACT_PROMPT_TEMPLATES['english']['completion']['prompt']), - next_iteration=agent_prompt.get('next_iteration', - REACT_PROMPT_TEMPLATES['english']['completion'][ - 'agent_scratchpad']), - ) - else: - agent_prompt_entity = AgentPromptEntity( - first_prompt=agent_prompt.get('first_prompt', - REACT_PROMPT_TEMPLATES['english']['chat']['prompt']), - next_iteration=agent_prompt.get('next_iteration', - REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']), - ) - - properties['agent'] = AgentEntity( - provider=properties['model_config'].provider, - model=properties['model_config'].model, - strategy=strategy, - prompt=agent_prompt_entity, - tools=agent_tools, - max_iteration=agent_dict.get('max_iteration', 5) - ) - - if len(dataset_ids) > 0: - # dataset configs - dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'}) - query_variable = copy_app_model_config_dict.get('dataset_query_variable') - - if dataset_configs['retrieval_model'] == 'single': - properties['dataset'] = DatasetEntity( - dataset_ids=dataset_ids, - retrieve_config=DatasetRetrieveConfigEntity( - query_variable=query_variable, - retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( - dataset_configs['retrieval_model'] - ) - ) - ) - else: - properties['dataset'] = DatasetEntity( - dataset_ids=dataset_ids, - retrieve_config=DatasetRetrieveConfigEntity( - query_variable=query_variable, - retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( - dataset_configs['retrieval_model'] - ), - top_k=dataset_configs.get('top_k'), - score_threshold=dataset_configs.get('score_threshold'), - reranking_model=dataset_configs.get('reranking_model') - ) - ) - - # file upload - file_upload_dict = copy_app_model_config_dict.get('file_upload') - if file_upload_dict: - if 'image' in file_upload_dict and file_upload_dict['image']: - if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']: - properties['file_upload'] = FileUploadEntity( - image_config={ - 'number_limits': file_upload_dict['image']['number_limits'], - 'detail': file_upload_dict['image']['detail'], - 'transfer_methods': file_upload_dict['image']['transfer_methods'] - } - ) - - # opening statement - properties['opening_statement'] = copy_app_model_config_dict.get('opening_statement') - - # suggested questions after answer - suggested_questions_after_answer_dict = copy_app_model_config_dict.get('suggested_questions_after_answer') - if suggested_questions_after_answer_dict: - if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']: - properties['suggested_questions_after_answer'] = True - - # more like this - more_like_this_dict = copy_app_model_config_dict.get('more_like_this') - if more_like_this_dict: - if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']: - properties['more_like_this'] = True - - # speech to text - speech_to_text_dict = copy_app_model_config_dict.get('speech_to_text') - if speech_to_text_dict: - if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']: - properties['speech_to_text'] = True - - # text to speech - text_to_speech_dict = copy_app_model_config_dict.get('text_to_speech') - if text_to_speech_dict: - if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']: - properties['text_to_speech'] = TextToSpeechEntity( - enabled=text_to_speech_dict.get('enabled'), - 
voice=text_to_speech_dict.get('voice'), - language=text_to_speech_dict.get('language'), - ) - - # sensitive word avoidance - sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance') - if sensitive_word_avoidance_dict: - if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']: - properties['sensitive_word_avoidance'] = SensitiveWordAvoidanceEntity( - type=sensitive_word_avoidance_dict.get('type'), - config=sensitive_word_avoidance_dict.get('config'), - ) - - return AppOrchestrationConfigEntity(**properties) diff --git a/api/core/app/app_queue_manager.py b/api/core/app/app_queue_manager.py index c09cae3245..4bd491269c 100644 --- a/api/core/app/app_queue_manager.py +++ b/api/core/app/app_queue_manager.py @@ -6,8 +6,8 @@ from typing import Any from sqlalchemy.orm import DeclarativeMeta -from core.entities.application_entities import InvokeFrom -from core.entities.queue_entities import ( +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import ( AnnotationReplyEvent, AppQueueEvent, QueueAgentMessageEvent, diff --git a/api/core/app/apps/__init__.py b/api/core/app/apps/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/app/apps/advanced_chat/__init__.py b/api/core/app/apps/advanced_chat/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/app/apps/advanced_chat/app_config_manager.py b/api/core/app/apps/advanced_chat/app_config_manager.py new file mode 100644 index 0000000000..ab7857c4ad --- /dev/null +++ b/api/core/app/apps/advanced_chat/app_config_manager.py @@ -0,0 +1,94 @@ +from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager +from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager +from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import \ + SuggestedQuestionsAfterAnswerConfigManager +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager +from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager +from models.model import AppMode, App +from models.workflow import Workflow + + +class AdvancedChatAppConfig(WorkflowUIBasedAppConfig): + """ + Advanced Chatbot App Config Entity. 
+ """ + pass + + +class AdvancedChatAppConfigManager(BaseAppConfigManager): + @classmethod + def config_convert(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig: + features_dict = workflow.features_dict + + app_config = AdvancedChatAppConfig( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + app_mode=AppMode.value_of(app_model.mode), + workflow_id=workflow.id, + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( + config=features_dict + ), + variables=WorkflowVariablesConfigManager.convert( + workflow=workflow + ), + additional_features=cls.convert_features(features_dict) + ) + + return app_config + + @classmethod + def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: + """ + Validate for advanced chat app model config + + :param tenant_id: tenant id + :param config: app model config args + :param only_structure_validate: if True, only structure validation will be performed + """ + related_config_keys = [] + + # file upload validation + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # opening_statement + config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # suggested_questions_after_answer + config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + config) + related_config_keys.extend(current_related_config_keys) + + # speech_to_text + config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # return retriever resource + config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id=tenant_id, + config=config, + only_structure_validate=only_structure_validate + ) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config + diff --git a/api/core/app/apps/agent_chat/__init__.py b/api/core/app/apps/agent_chat/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/app/agent_chat/config_validator.py b/api/core/app/apps/agent_chat/app_config_manager.py similarity index 51% rename from api/core/app/agent_chat/config_validator.py rename to api/core/app/apps/agent_chat/app_config_manager.py index 82bc40bd9b..96dac4bd01 100644 --- a/api/core/app/agent_chat/config_validator.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -1,24 +1,82 @@ import uuid +from typing import Optional -from core.app.validators.dataset_retrieval import DatasetValidator -from core.app.validators.external_data_fetch import ExternalDataFetchValidator -from core.app.validators.file_upload import FileUploadValidator -from core.app.validators.model_validator import ModelValidator -from core.app.validators.moderation import 
ModerationValidator -from core.app.validators.opening_statement import OpeningStatementValidator -from core.app.validators.prompt import PromptValidator -from core.app.validators.retriever_resource import RetrieverResourceValidator -from core.app.validators.speech_to_text import SpeechToTextValidator -from core.app.validators.suggested_questions import SuggestedQuestionsValidator -from core.app.validators.text_to_speech import TextToSpeechValidator -from core.app.validators.user_input_form import UserInputFormValidator +from core.agent.entities import AgentEntity +from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager +from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager +from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager +from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager +from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager +from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom, DatasetEntity +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager +from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager +from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import \ + SuggestedQuestionsAfterAnswerConfigManager +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager from core.entities.agent_entities import PlanningStrategy -from models.model import AppMode +from models.model import AppMode, App, AppModelConfig OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] -class AgentChatAppConfigValidator: +class AgentChatAppConfig(EasyUIBasedAppConfig): + """ + Agent Chatbot App Config Entity. 
+ """ + agent: Optional[AgentEntity] = None + + +class AgentChatAppConfigManager(BaseAppConfigManager): + @classmethod + def config_convert(cls, app_model: App, + config_from: EasyUIBasedAppModelConfigFrom, + app_model_config: AppModelConfig, + config_dict: Optional[dict] = None) -> AgentChatAppConfig: + """ + Convert app model config to agent chat app config + :param app_model: app model + :param config_from: app model config from + :param app_model_config: app model config + :param config_dict: app model config dict + :return: + """ + config_dict = cls.convert_to_config_dict(config_from, app_model_config, config_dict) + + app_config = AgentChatAppConfig( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + app_mode=AppMode.value_of(app_model.mode), + app_model_config_from=config_from, + app_model_config_id=app_model_config.id, + app_model_config_dict=config_dict, + model=ModelConfigManager.convert( + config=config_dict + ), + prompt_template=PromptTemplateConfigManager.convert( + config=config_dict + ), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( + config=config_dict + ), + dataset=DatasetConfigManager.convert( + config=config_dict + ), + agent=AgentConfigManager.convert( + config=config_dict + ), + additional_features=cls.convert_features(config_dict) + ) + + app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( + config=config_dict + ) + + return app_config + @classmethod def config_validate(cls, tenant_id: str, config: dict) -> dict: """ @@ -32,23 +90,19 @@ class AgentChatAppConfigValidator: related_config_keys = [] # model - config, current_related_config_keys = ModelValidator.validate_and_set_defaults(tenant_id, config) + config, current_related_config_keys = ModelConfigManager.validate_and_set_defaults(tenant_id, config) related_config_keys.extend(current_related_config_keys) # user_input_form - config, current_related_config_keys = UserInputFormValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # external data tools validation - config, current_related_config_keys = ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, config) + config, current_related_config_keys = BasicVariablesConfigManager.validate_and_set_defaults(tenant_id, config) related_config_keys.extend(current_related_config_keys) # file upload validation - config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config) related_config_keys.extend(current_related_config_keys) # prompt - config, current_related_config_keys = PromptValidator.validate_and_set_defaults(app_mode, config) + config, current_related_config_keys = PromptTemplateConfigManager.validate_and_set_defaults(app_mode, config) related_config_keys.extend(current_related_config_keys) # agent_mode @@ -56,27 +110,29 @@ class AgentChatAppConfigValidator: related_config_keys.extend(current_related_config_keys) # opening_statement - config, current_related_config_keys = OpeningStatementValidator.validate_and_set_defaults(config) + config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config) related_config_keys.extend(current_related_config_keys) # suggested_questions_after_answer - config, current_related_config_keys = SuggestedQuestionsValidator.validate_and_set_defaults(config) + config, current_related_config_keys = 
SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + config) related_config_keys.extend(current_related_config_keys) # speech_to_text - config, current_related_config_keys = SpeechToTextValidator.validate_and_set_defaults(config) + config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config) related_config_keys.extend(current_related_config_keys) # text_to_speech - config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) + config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config) related_config_keys.extend(current_related_config_keys) # return retriever resource - config, current_related_config_keys = RetrieverResourceValidator.validate_and_set_defaults(config) + config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config) related_config_keys.extend(current_related_config_keys) # moderation validation - config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config) + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, + config) related_config_keys.extend(current_related_config_keys) related_config_keys = list(set(related_config_keys)) @@ -143,7 +199,7 @@ class AgentChatAppConfigValidator: except ValueError: raise ValueError("id in dataset must be of UUID type") - if not DatasetValidator.is_dataset_exists(tenant_id, tool_item["id"]): + if not DatasetConfigManager.is_dataset_exists(tenant_id, tool_item["id"]): raise ValueError("Dataset ID does not exist, please check your permission.") else: # latest style, use key-value pair diff --git a/api/core/app/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py similarity index 83% rename from api/core/app/agent_chat/app_runner.py rename to api/core/app/apps/agent_chat/app_runner.py index 38789348ad..2f1de8f108 100644 --- a/api/core/app/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -2,10 +2,12 @@ import logging from typing import cast from core.agent.cot_agent_runner import CotAgentRunner +from core.agent.entities import AgentEntity from core.agent.fc_agent_runner import FunctionCallAgentRunner from core.app.app_queue_manager import AppQueueManager, PublishFrom -from core.app.base_app_runner import AppRunner -from core.entities.application_entities import AgentEntity, ApplicationGenerateEntity, ModelConfigEntity +from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig +from core.app.apps.base_app_runner import AppRunner +from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity, EasyUIBasedModelConfigEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMUsage @@ -24,7 +26,7 @@ class AgentChatAppRunner(AppRunner): """ Agent Application Runner """ - def run(self, application_generate_entity: ApplicationGenerateEntity, + def run(self, application_generate_entity: EasyUIBasedAppGenerateEntity, queue_manager: AppQueueManager, conversation: Conversation, message: Message) -> None: @@ -36,12 +38,13 @@ class AgentChatAppRunner(AppRunner): :param message: message :return: """ - app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first() + app_config = application_generate_entity.app_config + app_config = cast(AgentChatAppConfig, app_config) + + 
app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: raise ValueError("App not found") - app_orchestration_config = application_generate_entity.app_orchestration_config_entity - inputs = application_generate_entity.inputs query = application_generate_entity.query files = application_generate_entity.files @@ -53,8 +56,8 @@ class AgentChatAppRunner(AppRunner): # Not Include: memory, external data, dataset context self.get_pre_calculate_rest_tokens( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query @@ -64,22 +67,22 @@ class AgentChatAppRunner(AppRunner): if application_generate_entity.conversation_id: # get memory of conversation (read-only) model_instance = ModelInstance( - provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, - model=app_orchestration_config.model_config.model + provider_model_bundle=application_generate_entity.model_config.provider_model_bundle, + model=application_generate_entity.model_config.model ) memory = TokenBufferMemory( conversation=conversation, model_instance=model_instance ) - + # organize all inputs and template to prompt messages # Include: prompt template, inputs, query(optional), files(optional) # memory(optional) prompt_messages, _ = self.organize_prompt_messages( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query, @@ -91,15 +94,15 @@ class AgentChatAppRunner(AppRunner): # process sensitive_word_avoidance _, inputs, query = self.moderation_for_inputs( app_id=app_record.id, - tenant_id=application_generate_entity.tenant_id, - app_orchestration_config_entity=app_orchestration_config, + tenant_id=app_config.tenant_id, + app_generate_entity=application_generate_entity, inputs=inputs, query=query, ) except ModerationException as e: self.direct_output( queue_manager=queue_manager, - app_orchestration_config=app_orchestration_config, + app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=str(e), stream=application_generate_entity.stream @@ -123,7 +126,7 @@ class AgentChatAppRunner(AppRunner): ) self.direct_output( queue_manager=queue_manager, - app_orchestration_config=app_orchestration_config, + app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=annotation_reply.content, stream=application_generate_entity.stream @@ -131,7 +134,7 @@ class AgentChatAppRunner(AppRunner): return # fill in variable inputs from external data tools if exists - external_data_tools = app_orchestration_config.external_data_variables + external_data_tools = app_config.external_data_variables if external_data_tools: inputs = self.fill_in_inputs_from_external_data_tools( tenant_id=app_record.tenant_id, @@ -146,8 +149,8 @@ class AgentChatAppRunner(AppRunner): # memory(optional), external data, dataset context(optional) prompt_messages, _ = self.organize_prompt_messages( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + 
prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query, @@ -164,25 +167,25 @@ class AgentChatAppRunner(AppRunner): if hosting_moderation_result: return - agent_entity = app_orchestration_config.agent + agent_entity = app_config.agent # load tool variables tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id, user_id=application_generate_entity.user_id, - tenant_id=application_generate_entity.tenant_id) + tenant_id=app_config.tenant_id) # convert db variables to tool variables tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables) # init model instance model_instance = ModelInstance( - provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, - model=app_orchestration_config.model_config.model + provider_model_bundle=application_generate_entity.model_config.provider_model_bundle, + model=application_generate_entity.model_config.model ) prompt_message, _ = self.organize_prompt_messages( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query, @@ -203,10 +206,10 @@ class AgentChatAppRunner(AppRunner): # start agent runner if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: assistant_cot_runner = CotAgentRunner( - tenant_id=application_generate_entity.tenant_id, + tenant_id=app_config.tenant_id, application_generate_entity=application_generate_entity, - app_orchestration_config=app_orchestration_config, - model_config=app_orchestration_config.model_config, + app_config=app_config, + model_config=application_generate_entity.model_config, config=agent_entity, queue_manager=queue_manager, message=message, @@ -225,10 +228,10 @@ class AgentChatAppRunner(AppRunner): ) elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING: assistant_fc_runner = FunctionCallAgentRunner( - tenant_id=application_generate_entity.tenant_id, + tenant_id=app_config.tenant_id, application_generate_entity=application_generate_entity, - app_orchestration_config=app_orchestration_config, - model_config=app_orchestration_config.model_config, + app_config=app_config, + model_config=application_generate_entity.model_config, config=agent_entity, queue_manager=queue_manager, message=message, @@ -289,7 +292,7 @@ class AgentChatAppRunner(AppRunner): 'pool': db_variables.variables }) - def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity, + def _get_usage_of_all_agent_thoughts(self, model_config: EasyUIBasedModelConfigEntity, message: Message) -> LLMUsage: """ Get usage of all agent thoughts diff --git a/api/core/app/base_app_runner.py b/api/core/app/apps/base_app_runner.py similarity index 93% rename from api/core/app/base_app_runner.py rename to api/core/app/apps/base_app_runner.py index 2760d04180..93f819af08 100644 --- a/api/core/app/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -2,16 +2,13 @@ import time from collections.abc import Generator from typing import Optional, Union, cast +from core.app.app_config.entities import PromptTemplateEntity, ExternalDataVariableEntity from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature from core.app.features.hosting_moderation.hosting_moderation import 
HostingModerationFeature -from core.entities.application_entities import ( - ApplicationGenerateEntity, - AppOrchestrationConfigEntity, - ExternalDataVariableEntity, - InvokeFrom, - ModelConfigEntity, - PromptTemplateEntity, +from core.app.entities.app_invoke_entities import ( + EasyUIBasedAppGenerateEntity, + InvokeFrom, EasyUIBasedModelConfigEntity, ) from core.external_data_tool.external_data_fetch import ExternalDataFetch from core.file.file_obj import FileObj @@ -29,7 +26,7 @@ from models.model import App, AppMode, Message, MessageAnnotation class AppRunner: def get_pre_calculate_rest_tokens(self, app_record: App, - model_config: ModelConfigEntity, + model_config: EasyUIBasedModelConfigEntity, prompt_template_entity: PromptTemplateEntity, inputs: dict[str, str], files: list[FileObj], @@ -85,7 +82,7 @@ class AppRunner: return rest_tokens - def recalc_llm_max_tokens(self, model_config: ModelConfigEntity, + def recale_llm_max_tokens(self, model_config: EasyUIBasedModelConfigEntity, prompt_messages: list[PromptMessage]): # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit model_type_instance = model_config.provider_model_bundle.model_type_instance @@ -121,7 +118,7 @@ class AppRunner: model_config.parameters[parameter_rule.name] = max_tokens def organize_prompt_messages(self, app_record: App, - model_config: ModelConfigEntity, + model_config: EasyUIBasedModelConfigEntity, prompt_template_entity: PromptTemplateEntity, inputs: dict[str, str], files: list[FileObj], @@ -170,7 +167,7 @@ class AppRunner: return prompt_messages, stop def direct_output(self, queue_manager: AppQueueManager, - app_orchestration_config: AppOrchestrationConfigEntity, + app_generate_entity: EasyUIBasedAppGenerateEntity, prompt_messages: list, text: str, stream: bool, @@ -178,7 +175,7 @@ class AppRunner: """ Direct output :param queue_manager: application queue manager - :param app_orchestration_config: app orchestration config + :param app_generate_entity: app generate entity :param prompt_messages: prompt messages :param text: text :param stream: stream @@ -189,7 +186,7 @@ class AppRunner: index = 0 for token in text: queue_manager.publish_chunk_message(LLMResultChunk( - model=app_orchestration_config.model_config.model, + model=app_generate_entity.model_config.model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, @@ -201,7 +198,7 @@ class AppRunner: queue_manager.publish_message_end( llm_result=LLMResult( - model=app_orchestration_config.model_config.model, + model=app_generate_entity.model_config.model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), usage=usage if usage else LLMUsage.empty_usage() @@ -294,14 +291,14 @@ class AppRunner: def moderation_for_inputs(self, app_id: str, tenant_id: str, - app_orchestration_config_entity: AppOrchestrationConfigEntity, + app_generate_entity: EasyUIBasedAppGenerateEntity, inputs: dict, query: str) -> tuple[bool, dict, str]: """ Process sensitive_word_avoidance. 
:param app_id: app id :param tenant_id: tenant id - :param app_orchestration_config_entity: app orchestration config entity + :param app_generate_entity: app generate entity :param inputs: inputs :param query: query :return: @@ -310,12 +307,12 @@ class AppRunner: return moderation_feature.check( app_id=app_id, tenant_id=tenant_id, - app_orchestration_config_entity=app_orchestration_config_entity, + app_config=app_generate_entity.app_config, inputs=inputs, query=query, ) - def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity, + def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity, queue_manager: AppQueueManager, prompt_messages: list[PromptMessage]) -> bool: """ @@ -334,7 +331,7 @@ class AppRunner: if moderation_result: self.direct_output( queue_manager=queue_manager, - app_orchestration_config=application_generate_entity.app_orchestration_config_entity, + app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text="I apologize for any confusion, " \ "but I'm an AI assistant to be helpful, harmless, and honest.", diff --git a/api/core/app/apps/chat/__init__.py b/api/core/app/apps/chat/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/app/apps/chat/app_config_manager.py b/api/core/app/apps/chat/app_config_manager.py new file mode 100644 index 0000000000..62b2aaae5a --- /dev/null +++ b/api/core/app/apps/chat/app_config_manager.py @@ -0,0 +1,135 @@ +from typing import Optional + +from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager +from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager +from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager +from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager +from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager +from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager +from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import \ + SuggestedQuestionsAfterAnswerConfigManager +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager +from models.model import AppMode, App, AppModelConfig + + +class ChatAppConfig(EasyUIBasedAppConfig): + """ + Chatbot App Config Entity. 
+ """ + pass + + +class ChatAppConfigManager(BaseAppConfigManager): + @classmethod + def config_convert(cls, app_model: App, + config_from: EasyUIBasedAppModelConfigFrom, + app_model_config: AppModelConfig, + config_dict: Optional[dict] = None) -> ChatAppConfig: + """ + Convert app model config to chat app config + :param app_model: app model + :param config_from: app model config from + :param app_model_config: app model config + :param config_dict: app model config dict + :return: + """ + config_dict = cls.convert_to_config_dict(config_from, app_model_config, config_dict) + + app_config = ChatAppConfig( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + app_mode=AppMode.value_of(app_model.mode), + app_model_config_from=config_from, + app_model_config_id=app_model_config.id, + app_model_config_dict=config_dict, + model=ModelConfigManager.convert( + config=config_dict + ), + prompt_template=PromptTemplateConfigManager.convert( + config=config_dict + ), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( + config=config_dict + ), + dataset=DatasetConfigManager.convert( + config=config_dict + ), + additional_features=cls.convert_features(config_dict) + ) + + app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( + config=config_dict + ) + + return app_config + + @classmethod + def config_validate(cls, tenant_id: str, config: dict) -> dict: + """ + Validate for chat app model config + + :param tenant_id: tenant id + :param config: app model config args + """ + app_mode = AppMode.CHAT + + related_config_keys = [] + + # model + config, current_related_config_keys = ModelConfigManager.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # user_input_form + config, current_related_config_keys = BasicVariablesConfigManager.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # file upload validation + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # prompt + config, current_related_config_keys = PromptTemplateConfigManager.validate_and_set_defaults(app_mode, config) + related_config_keys.extend(current_related_config_keys) + + # dataset_query_variable + config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, + config) + related_config_keys.extend(current_related_config_keys) + + # opening_statement + config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # suggested_questions_after_answer + config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + config) + related_config_keys.extend(current_related_config_keys) + + # speech_to_text + config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # return retriever resource + config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, 
current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, + config) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config diff --git a/api/core/app/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py similarity index 76% rename from api/core/app/chat/app_runner.py rename to api/core/app/apps/chat/app_runner.py index 4c8018572e..403a2d4476 100644 --- a/api/core/app/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -1,10 +1,12 @@ import logging +from typing import cast from core.app.app_queue_manager import AppQueueManager, PublishFrom -from core.app.base_app_runner import AppRunner +from core.app.apps.chat.app_config_manager import ChatAppConfig +from core.app.apps.base_app_runner import AppRunner from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.entities.application_entities import ( - ApplicationGenerateEntity, +from core.app.entities.app_invoke_entities import ( + EasyUIBasedAppGenerateEntity, ) from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance @@ -21,7 +23,7 @@ class ChatAppRunner(AppRunner): Chat Application Runner """ - def run(self, application_generate_entity: ApplicationGenerateEntity, + def run(self, application_generate_entity: EasyUIBasedAppGenerateEntity, queue_manager: AppQueueManager, conversation: Conversation, message: Message) -> None: @@ -33,12 +35,13 @@ class ChatAppRunner(AppRunner): :param message: message :return: """ - app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first() + app_config = application_generate_entity.app_config + app_config = cast(ChatAppConfig, app_config) + + app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: raise ValueError("App not found") - app_orchestration_config = application_generate_entity.app_orchestration_config_entity - inputs = application_generate_entity.inputs query = application_generate_entity.query files = application_generate_entity.files @@ -50,8 +53,8 @@ class ChatAppRunner(AppRunner): # Not Include: memory, external data, dataset context self.get_pre_calculate_rest_tokens( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query @@ -61,8 +64,8 @@ class ChatAppRunner(AppRunner): if application_generate_entity.conversation_id: # get memory of conversation (read-only) model_instance = ModelInstance( - provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, - model=app_orchestration_config.model_config.model + provider_model_bundle=application_generate_entity.model_config.provider_model_bundle, + model=application_generate_entity.model_config.model ) memory = TokenBufferMemory( @@ -75,8 +78,8 @@ class ChatAppRunner(AppRunner): # memory(optional) prompt_messages, stop = self.organize_prompt_messages( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + 
prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query, @@ -88,15 +91,15 @@ class ChatAppRunner(AppRunner): # process sensitive_word_avoidance _, inputs, query = self.moderation_for_inputs( app_id=app_record.id, - tenant_id=application_generate_entity.tenant_id, - app_orchestration_config_entity=app_orchestration_config, + tenant_id=app_config.tenant_id, + app_generate_entity=application_generate_entity, inputs=inputs, query=query, ) except ModerationException as e: self.direct_output( queue_manager=queue_manager, - app_orchestration_config=app_orchestration_config, + app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=str(e), stream=application_generate_entity.stream @@ -120,7 +123,7 @@ class ChatAppRunner(AppRunner): ) self.direct_output( queue_manager=queue_manager, - app_orchestration_config=app_orchestration_config, + app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=annotation_reply.content, stream=application_generate_entity.stream @@ -128,7 +131,7 @@ class ChatAppRunner(AppRunner): return # fill in variable inputs from external data tools if exists - external_data_tools = app_orchestration_config.external_data_variables + external_data_tools = app_config.external_data_variables if external_data_tools: inputs = self.fill_in_inputs_from_external_data_tools( tenant_id=app_record.tenant_id, @@ -140,7 +143,7 @@ class ChatAppRunner(AppRunner): # get context from datasets context = None - if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids: + if app_config.dataset and app_config.dataset.dataset_ids: hit_callback = DatasetIndexToolCallbackHandler( queue_manager, app_record.id, @@ -152,11 +155,11 @@ class ChatAppRunner(AppRunner): dataset_retrieval = DatasetRetrieval() context = dataset_retrieval.retrieve( tenant_id=app_record.tenant_id, - model_config=app_orchestration_config.model_config, - config=app_orchestration_config.dataset, + model_config=application_generate_entity.model_config, + config=app_config.dataset, query=query, invoke_from=application_generate_entity.invoke_from, - show_retrieve_source=app_orchestration_config.show_retrieve_source, + show_retrieve_source=app_config.additional_features.show_retrieve_source, hit_callback=hit_callback, memory=memory ) @@ -166,8 +169,8 @@ class ChatAppRunner(AppRunner): # memory(optional), external data, dataset context(optional) prompt_messages, stop = self.organize_prompt_messages( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query, @@ -186,22 +189,22 @@ class ChatAppRunner(AppRunner): return # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit - self.recalc_llm_max_tokens( - model_config=app_orchestration_config.model_config, + self.recale_llm_max_tokens( + model_config=application_generate_entity.model_config, prompt_messages=prompt_messages ) # Invoke model model_instance = ModelInstance( - provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, - model=app_orchestration_config.model_config.model + provider_model_bundle=application_generate_entity.model_config.provider_model_bundle, + model=application_generate_entity.model_config.model ) db.session.close() invoke_result = model_instance.invoke_llm( 
prompt_messages=prompt_messages, - model_parameters=app_orchestration_config.model_config.parameters, + model_parameters=application_generate_entity.model_config.parameters, stop=stop, stream=application_generate_entity.stream, user=application_generate_entity.user_id, diff --git a/api/core/app/apps/completion/__init__.py b/api/core/app/apps/completion/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py new file mode 100644 index 0000000000..b920f369b5 --- /dev/null +++ b/api/core/app/apps/completion/app_config_manager.py @@ -0,0 +1,118 @@ +from typing import Optional + +from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager +from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager +from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager +from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager +from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager +from models.model import AppMode, App, AppModelConfig + + +class CompletionAppConfig(EasyUIBasedAppConfig): + """ + Completion App Config Entity. 
+ """ + pass + + +class CompletionAppConfigManager(BaseAppConfigManager): + @classmethod + def config_convert(cls, app_model: App, + config_from: EasyUIBasedAppModelConfigFrom, + app_model_config: AppModelConfig, + config_dict: Optional[dict] = None) -> CompletionAppConfig: + """ + Convert app model config to completion app config + :param app_model: app model + :param config_from: app model config from + :param app_model_config: app model config + :param config_dict: app model config dict + :return: + """ + config_dict = cls.convert_to_config_dict(config_from, app_model_config, config_dict) + + app_config = CompletionAppConfig( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + app_mode=AppMode.value_of(app_model.mode), + app_model_config_from=config_from, + app_model_config_id=app_model_config.id, + app_model_config_dict=config_dict, + model=ModelConfigManager.convert( + config=config_dict + ), + prompt_template=PromptTemplateConfigManager.convert( + config=config_dict + ), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( + config=config_dict + ), + dataset=DatasetConfigManager.convert( + config=config_dict + ), + additional_features=cls.convert_features(config_dict) + ) + + app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( + config=config_dict + ) + + return app_config + + @classmethod + def config_validate(cls, tenant_id: str, config: dict) -> dict: + """ + Validate for completion app model config + + :param tenant_id: tenant id + :param config: app model config args + """ + app_mode = AppMode.COMPLETION + + related_config_keys = [] + + # model + config, current_related_config_keys = ModelConfigManager.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # user_input_form + config, current_related_config_keys = BasicVariablesConfigManager.validate_and_set_defaults(tenant_id, config) + related_config_keys.extend(current_related_config_keys) + + # file upload validation + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # prompt + config, current_related_config_keys = PromptTemplateConfigManager.validate_and_set_defaults(app_mode, config) + related_config_keys.extend(current_related_config_keys) + + # dataset_query_variable + config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, + config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # more_like_this + config, current_related_config_keys = MoreLikeThisConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, + config) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config diff --git a/api/core/app/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py similarity index 74% rename from api/core/app/completion/app_runner.py rename to 
api/core/app/apps/completion/app_runner.py index ab2f40ad9a..8f0f191d45 100644 --- a/api/core/app/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -1,10 +1,12 @@ import logging +from typing import cast from core.app.app_queue_manager import AppQueueManager -from core.app.base_app_runner import AppRunner +from core.app.apps.completion.app_config_manager import CompletionAppConfig +from core.app.apps.base_app_runner import AppRunner from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.entities.application_entities import ( - ApplicationGenerateEntity, +from core.app.entities.app_invoke_entities import ( + EasyUIBasedAppGenerateEntity, ) from core.model_manager import ModelInstance from core.moderation.base import ModerationException @@ -20,7 +22,7 @@ class CompletionAppRunner(AppRunner): Completion Application Runner """ - def run(self, application_generate_entity: ApplicationGenerateEntity, + def run(self, application_generate_entity: EasyUIBasedAppGenerateEntity, queue_manager: AppQueueManager, message: Message) -> None: """ @@ -30,12 +32,13 @@ class CompletionAppRunner(AppRunner): :param message: message :return: """ - app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first() + app_config = application_generate_entity.app_config + app_config = cast(CompletionAppConfig, app_config) + + app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: raise ValueError("App not found") - app_orchestration_config = application_generate_entity.app_orchestration_config_entity - inputs = application_generate_entity.inputs query = application_generate_entity.query files = application_generate_entity.files @@ -47,8 +50,8 @@ class CompletionAppRunner(AppRunner): # Not Include: memory, external data, dataset context self.get_pre_calculate_rest_tokens( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query @@ -58,8 +61,8 @@ class CompletionAppRunner(AppRunner): # Include: prompt template, inputs, query(optional), files(optional) prompt_messages, stop = self.organize_prompt_messages( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query @@ -70,15 +73,15 @@ class CompletionAppRunner(AppRunner): # process sensitive_word_avoidance _, inputs, query = self.moderation_for_inputs( app_id=app_record.id, - tenant_id=application_generate_entity.tenant_id, - app_orchestration_config_entity=app_orchestration_config, + tenant_id=app_config.tenant_id, + app_generate_entity=application_generate_entity, inputs=inputs, query=query, ) except ModerationException as e: self.direct_output( queue_manager=queue_manager, - app_orchestration_config=app_orchestration_config, + app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=str(e), stream=application_generate_entity.stream @@ -86,7 +89,7 @@ class CompletionAppRunner(AppRunner): return # fill in variable inputs from external data tools if exists - external_data_tools = app_orchestration_config.external_data_variables + 
external_data_tools = app_config.external_data_variables if external_data_tools: inputs = self.fill_in_inputs_from_external_data_tools( tenant_id=app_record.tenant_id, @@ -98,7 +101,7 @@ class CompletionAppRunner(AppRunner): # get context from datasets context = None - if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids: + if app_config.dataset and app_config.dataset.dataset_ids: hit_callback = DatasetIndexToolCallbackHandler( queue_manager, app_record.id, @@ -107,18 +110,18 @@ class CompletionAppRunner(AppRunner): application_generate_entity.invoke_from ) - dataset_config = app_orchestration_config.dataset + dataset_config = app_config.dataset if dataset_config and dataset_config.retrieve_config.query_variable: query = inputs.get(dataset_config.retrieve_config.query_variable, "") dataset_retrieval = DatasetRetrieval() context = dataset_retrieval.retrieve( tenant_id=app_record.tenant_id, - model_config=app_orchestration_config.model_config, + model_config=application_generate_entity.model_config, config=dataset_config, query=query, invoke_from=application_generate_entity.invoke_from, - show_retrieve_source=app_orchestration_config.show_retrieve_source, + show_retrieve_source=app_config.additional_features.show_retrieve_source, hit_callback=hit_callback ) @@ -127,8 +130,8 @@ class CompletionAppRunner(AppRunner): # memory(optional), external data, dataset context(optional) prompt_messages, stop = self.organize_prompt_messages( app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, + model_config=application_generate_entity.model_config, + prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, query=query, @@ -147,19 +150,19 @@ class CompletionAppRunner(AppRunner): # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit self.recale_llm_max_tokens( - model_config=app_orchestration_config.model_config, + model_config=application_generate_entity.model_config, prompt_messages=prompt_messages ) # Invoke model model_instance = ModelInstance( - provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, - model=app_orchestration_config.model_config.model + provider_model_bundle=application_generate_entity.model_config.provider_model_bundle, + model=application_generate_entity.model_config.model ) invoke_result = model_instance.invoke_llm( prompt_messages=prompt_messages, - model_parameters=app_orchestration_config.model_config.parameters, + model_parameters=application_generate_entity.model_config.parameters, stop=stop, stream=application_generate_entity.stream, user=application_generate_entity.user_id, diff --git a/api/core/app/apps/workflow/__init__.py b/api/core/app/apps/workflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/app/apps/workflow/app_config_manager.py b/api/core/app/apps/workflow/app_config_manager.py new file mode 100644 index 0000000000..35da72b63e --- /dev/null +++ b/api/core/app/apps/workflow/app_config_manager.py @@ -0,0 +1,71 @@ +from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_config.features.text_to_speech.manager import 
diff --git a/api/core/app/apps/workflow/app_config_manager.py b/api/core/app/apps/workflow/app_config_manager.py
new file mode 100644
index 0000000000..35da72b63e
--- /dev/null
+++ b/api/core/app/apps/workflow/app_config_manager.py
@@ -0,0 +1,71 @@
+from core.app.app_config.base_app_config_manager import BaseAppConfigManager
+from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
+from core.app.app_config.entities import WorkflowUIBasedAppConfig
+from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
+from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
+from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
+from models.model import AppMode, App
+from models.workflow import Workflow
+
+
+class WorkflowAppConfig(WorkflowUIBasedAppConfig):
+    """
+    Workflow App Config Entity.
+    """
+    pass
+
+
+class WorkflowAppConfigManager(BaseAppConfigManager):
+    @classmethod
+    def config_convert(cls, app_model: App, workflow: Workflow) -> WorkflowAppConfig:
+        features_dict = workflow.features_dict
+
+        app_config = WorkflowAppConfig(
+            tenant_id=app_model.tenant_id,
+            app_id=app_model.id,
+            app_mode=AppMode.value_of(app_model.mode),
+            workflow_id=workflow.id,
+            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
+                config=features_dict
+            ),
+            variables=WorkflowVariablesConfigManager.convert(
+                workflow=workflow
+            ),
+            additional_features=cls.convert_features(features_dict)
+        )
+
+        return app_config
+
+    @classmethod
+    def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
+        """
+        Validate for workflow app model config
+
+        :param tenant_id: tenant id
+        :param config: app model config args
+        :param only_structure_validate: only validate the structure of the config
+        """
+        related_config_keys = []
+
+        # file upload validation
+        config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config)
+        related_config_keys.extend(current_related_config_keys)
+
+        # text_to_speech
+        config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
+        related_config_keys.extend(current_related_config_keys)
+
+        # moderation validation
+        config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
+            tenant_id=tenant_id,
+            config=config,
+            only_structure_validate=only_structure_validate
+        )
+        related_config_keys.extend(current_related_config_keys)
+
+        related_config_keys = list(set(related_config_keys))
+
+        # Filter out extra parameters
+        filtered_config = {key: config.get(key) for key in related_config_keys}
+
+        return filtered_config
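Every feature manager used in config_validate above shares one contract: validate_and_set_defaults(config) returns the config with defaults applied plus the list of top-level keys it owns, and config_validate keeps only the keys some manager claimed. A minimal self-contained sketch of that contract — the two validators here are illustrative stand-ins, not Dify's real managers:

from typing import Callable

def validate_file_upload(config: dict) -> tuple[dict, list[str]]:
    # Stand-in for FileUploadConfigManager.validate_and_set_defaults.
    config.setdefault('file_upload', {'image': {'enabled': False}})
    return config, ['file_upload']

def validate_text_to_speech(config: dict) -> tuple[dict, list[str]]:
    # Stand-in for TextToSpeechConfigManager.validate_and_set_defaults.
    config.setdefault('text_to_speech', {'enabled': False})
    return config, ['text_to_speech']

VALIDATORS: list[Callable[[dict], tuple[dict, list[str]]]] = [
    validate_file_upload,
    validate_text_to_speech,
]

def config_validate(config: dict) -> dict:
    related_config_keys: list[str] = []
    for validate in VALIDATORS:
        config, keys = validate(config)
        related_config_keys.extend(keys)
    # Deduplicate the claimed keys, then drop everything unclaimed.
    return {key: config.get(key) for key in set(related_config_keys)}

assert 'junk' not in config_validate({'text_to_speech': {'enabled': True}, 'junk': 1})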
diff --git a/api/core/app/chat/config_validator.py b/api/core/app/chat/config_validator.py
deleted file mode 100644
index adb8408e28..0000000000
--- a/api/core/app/chat/config_validator.py
+++ /dev/null
@@ -1,82 +0,0 @@
-from core.app.validators.dataset_retrieval import DatasetValidator
-from core.app.validators.external_data_fetch import ExternalDataFetchValidator
-from core.app.validators.file_upload import FileUploadValidator
-from core.app.validators.model_validator import ModelValidator
-from core.app.validators.moderation import ModerationValidator
-from core.app.validators.opening_statement import OpeningStatementValidator
-from core.app.validators.prompt import PromptValidator
-from core.app.validators.retriever_resource import RetrieverResourceValidator
-from core.app.validators.speech_to_text import SpeechToTextValidator
-from core.app.validators.suggested_questions import SuggestedQuestionsValidator
-from core.app.validators.text_to_speech import TextToSpeechValidator
-from core.app.validators.user_input_form import UserInputFormValidator
-from models.model import AppMode
-
-
-class ChatAppConfigValidator:
-    @classmethod
-    def config_validate(cls, tenant_id: str, config: dict) -> dict:
-        """
-        Validate for chat app model config
-
-        :param tenant_id: tenant id
-        :param config: app model config args
-        """
-        app_mode = AppMode.CHAT
-
-        related_config_keys = []
-
-        # model
-        config, current_related_config_keys = ModelValidator.validate_and_set_defaults(tenant_id, config)
-        related_config_keys.extend(current_related_config_keys)
-
-        # user_input_form
-        config, current_related_config_keys = UserInputFormValidator.validate_and_set_defaults(config)
-        related_config_keys.extend(current_related_config_keys)
-
-        # external data tools validation
-        config, current_related_config_keys = ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, config)
-        related_config_keys.extend(current_related_config_keys)
-
-        # file upload validation
-        config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config)
-        related_config_keys.extend(current_related_config_keys)
-
-        # prompt
-        config, current_related_config_keys = PromptValidator.validate_and_set_defaults(app_mode, config)
-        related_config_keys.extend(current_related_config_keys)
-
-        # dataset_query_variable
-        config, current_related_config_keys = DatasetValidator.validate_and_set_defaults(tenant_id, app_mode, config)
-        related_config_keys.extend(current_related_config_keys)
-
-        # opening_statement
-        config, current_related_config_keys = OpeningStatementValidator.validate_and_set_defaults(config)
-        related_config_keys.extend(current_related_config_keys)
-
-        # suggested_questions_after_answer
-        config, current_related_config_keys = SuggestedQuestionsValidator.validate_and_set_defaults(config)
-        related_config_keys.extend(current_related_config_keys)
-
-        # speech_to_text
-        config, current_related_config_keys = SpeechToTextValidator.validate_and_set_defaults(config)
-        related_config_keys.extend(current_related_config_keys)
-
-        # text_to_speech
-        config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config)
-        related_config_keys.extend(current_related_config_keys)
-
-        # return retriever resource
-        config, current_related_config_keys = RetrieverResourceValidator.validate_and_set_defaults(config)
-        related_config_keys.extend(current_related_config_keys)
-
-        # moderation validation
-        config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config)
-        related_config_keys.extend(current_related_config_keys)
-
-        related_config_keys = list(set(related_config_keys))
-
-        # Filter out extra parameters
-        filtered_config = {key: config.get(key) for key in related_config_keys}
-
-        return filtered_config
diff --git a/api/core/app/completion/config_validator.py b/api/core/app/completion/config_validator.py
deleted file mode 100644
index 7cc35efd64..0000000000
--- a/api/core/app/completion/config_validator.py
+++ /dev/null
@@ -1,67 +0,0 @@
-from core.app.validators.dataset_retrieval import DatasetValidator
-from core.app.validators.external_data_fetch import ExternalDataFetchValidator
-from core.app.validators.file_upload import FileUploadValidator
-from core.app.validators.model_validator import ModelValidator
-from core.app.validators.moderation import ModerationValidator
-from core.app.validators.more_like_this import MoreLikeThisValidator
-from core.app.validators.prompt import PromptValidator
-from core.app.validators.text_to_speech import TextToSpeechValidator
-from core.app.validators.user_input_form import UserInputFormValidator
-from models.model import AppMode
-
-
-class CompletionAppConfigValidator:
-    @classmethod
-    def config_validate(cls, tenant_id: str, config: dict) -> dict:
-        """
-        Validate for completion app model config
-
-        :param tenant_id: tenant id
-        :param config: app model config args
-        """
-        app_mode = AppMode.COMPLETION
-
-        related_config_keys = []
-
-        # model
-        config, current_related_config_keys = ModelValidator.validate_and_set_defaults(tenant_id, config)
-        related_config_keys.extend(current_related_config_keys)
-
-        # user_input_form
-        config, current_related_config_keys = UserInputFormValidator.validate_and_set_defaults(config)
-        related_config_keys.extend(current_related_config_keys)
-
-        # external data tools validation
-        config, current_related_config_keys = ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, config)
-        related_config_keys.extend(current_related_config_keys)
-
-        # file upload validation
-        config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config)
-        related_config_keys.extend(current_related_config_keys)
-
-        # prompt
-        config, current_related_config_keys = PromptValidator.validate_and_set_defaults(app_mode, config)
-        related_config_keys.extend(current_related_config_keys)
-
-        # dataset_query_variable
-        config, current_related_config_keys = DatasetValidator.validate_and_set_defaults(tenant_id, app_mode, config)
-        related_config_keys.extend(current_related_config_keys)
-
-        # text_to_speech
-        config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config)
-        related_config_keys.extend(current_related_config_keys)
-
-        # more_like_this
-        config, current_related_config_keys = MoreLikeThisValidator.validate_and_set_defaults(config)
-        related_config_keys.extend(current_related_config_keys)
-
-        # moderation validation
-        config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config)
-        related_config_keys.extend(current_related_config_keys)
-
-        related_config_keys = list(set(related_config_keys))
-
-        # Filter out extra parameters
-        filtered_config = {key: config.get(key) for key in related_config_keys}
-
-        return filtered_config
diff --git a/api/core/app/entities/__init__.py b/api/core/app/entities/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
+ """ + provider: str + model: str + model_schema: AIModelEntity + mode: str + provider_model_bundle: ProviderModelBundle + credentials: dict[str, Any] = {} + parameters: dict[str, Any] = {} + stop: list[str] = [] + + +class EasyUIBasedAppGenerateEntity(BaseModel): + """ + EasyUI Based Application Generate Entity. + """ + task_id: str + + # app config + app_config: EasyUIBasedAppConfig + model_config: EasyUIBasedModelConfigEntity + + conversation_id: Optional[str] = None + inputs: dict[str, str] + query: Optional[str] = None + files: list[FileObj] = [] + user_id: str + # extras + stream: bool + invoke_from: InvokeFrom + + # extra parameters, like: auto_generate_conversation_name + extras: dict[str, Any] = {} + + +class WorkflowUIBasedAppGenerateEntity(BaseModel): + """ + Workflow UI Based Application Generate Entity. + """ + task_id: str + + # app config + app_config: WorkflowUIBasedAppConfig + + inputs: dict[str, str] + files: list[FileObj] = [] + user_id: str + # extras + stream: bool + invoke_from: InvokeFrom + + # extra parameters + extras: dict[str, Any] = {} + + +class AdvancedChatAppGenerateEntity(WorkflowUIBasedAppGenerateEntity): + conversation_id: Optional[str] = None + query: str diff --git a/api/core/entities/queue_entities.py b/api/core/app/entities/queue_entities.py similarity index 100% rename from api/core/entities/queue_entities.py rename to api/core/app/entities/queue_entities.py diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index fd516e465f..19ff94de5e 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -1,7 +1,7 @@ import logging from typing import Optional -from core.entities.application_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.datasource.vdb.vector_factory import Vector from extensions.ext_database import db from models.dataset import Dataset diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py index d8ae7adcac..ec316248a2 100644 --- a/api/core/app/features/hosting_moderation/hosting_moderation.py +++ b/api/core/app/features/hosting_moderation/hosting_moderation.py @@ -1,6 +1,6 @@ import logging -from core.entities.application_entities import ApplicationGenerateEntity +from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from core.helper import moderation from core.model_runtime.entities.message_entities import PromptMessage @@ -8,7 +8,7 @@ logger = logging.getLogger(__name__) class HostingModerationFeature: - def check(self, application_generate_entity: ApplicationGenerateEntity, + def check(self, application_generate_entity: EasyUIBasedAppGenerateEntity, prompt_messages: list[PromptMessage]) -> bool: """ Check hosting moderation @@ -16,8 +16,7 @@ class HostingModerationFeature: :param prompt_messages: prompt messages :return: """ - app_orchestration_config = application_generate_entity.app_orchestration_config_entity - model_config = app_orchestration_config.model_config + model_config = application_generate_entity.model_config text = "" for prompt_message in prompt_messages: diff --git a/api/core/app/generate_task_pipeline.py b/api/core/app/generate_task_pipeline.py index dc6ea2db79..359369ef59 100644 --- a/api/core/app/generate_task_pipeline.py +++ b/api/core/app/generate_task_pipeline.py @@ -7,8 +7,8 @@ from typing 
diff --git a/api/core/entities/queue_entities.py b/api/core/app/entities/queue_entities.py
similarity index 100%
rename from api/core/entities/queue_entities.py
rename to api/core/app/entities/queue_entities.py
diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py
index fd516e465f..19ff94de5e 100644
--- a/api/core/app/features/annotation_reply/annotation_reply.py
+++ b/api/core/app/features/annotation_reply/annotation_reply.py
@@ -1,7 +1,7 @@
 import logging
 from typing import Optional
 
-from core.entities.application_entities import InvokeFrom
+from core.app.entities.app_invoke_entities import InvokeFrom
 from core.rag.datasource.vdb.vector_factory import Vector
 from extensions.ext_database import db
 from models.dataset import Dataset
diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py
index d8ae7adcac..ec316248a2 100644
--- a/api/core/app/features/hosting_moderation/hosting_moderation.py
+++ b/api/core/app/features/hosting_moderation/hosting_moderation.py
@@ -1,6 +1,6 @@
 import logging
 
-from core.entities.application_entities import ApplicationGenerateEntity
+from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity
 from core.helper import moderation
 from core.model_runtime.entities.message_entities import PromptMessage
 
@@ -8,7 +8,7 @@ logger = logging.getLogger(__name__)
 
 
 class HostingModerationFeature:
-    def check(self, application_generate_entity: ApplicationGenerateEntity,
+    def check(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
               prompt_messages: list[PromptMessage]) -> bool:
         """
         Check hosting moderation
@@ -16,8 +16,7 @@ class HostingModerationFeature:
         :param prompt_messages: prompt messages
         :return:
         """
-        app_orchestration_config = application_generate_entity.app_orchestration_config_entity
-        model_config = app_orchestration_config.model_config
+        model_config = application_generate_entity.model_config
 
         text = ""
         for prompt_message in prompt_messages:
diff --git a/api/core/app/generate_task_pipeline.py b/api/core/app/generate_task_pipeline.py
index dc6ea2db79..359369ef59 100644
--- a/api/core/app/generate_task_pipeline.py
+++ b/api/core/app/generate_task_pipeline.py
@@ -7,8 +7,8 @@
 from typing import Optional, Union, cast
 
 from pydantic import BaseModel
 
 from core.app.app_queue_manager import AppQueueManager, PublishFrom
-from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom
-from core.entities.queue_entities import (
+from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity, InvokeFrom
+from core.app.entities.queue_entities import (
     AnnotationReplyEvent,
     QueueAgentMessageEvent,
     QueueAgentThoughtEvent,
@@ -58,7 +58,7 @@ class GenerateTaskPipeline:
     GenerateTaskPipeline is a class that generates stream output and state management for Application.
     """
 
-    def __init__(self, application_generate_entity: ApplicationGenerateEntity,
+    def __init__(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
                  queue_manager: AppQueueManager,
                  conversation: Conversation,
                  message: Message) -> None:
@@ -75,7 +75,7 @@ class GenerateTaskPipeline:
         self._message = message
         self._task_state = TaskState(
             llm_result=LLMResult(
-                model=self._application_generate_entity.app_orchestration_config_entity.model_config.model,
+                model=self._application_generate_entity.model_config.model,
                 prompt_messages=[],
                 message=AssistantPromptMessage(content=""),
                 usage=LLMUsage.empty_usage()
@@ -127,7 +127,7 @@ class GenerateTaskPipeline:
             if isinstance(event, QueueMessageEndEvent):
                 self._task_state.llm_result = event.llm_result
             else:
-                model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
+                model_config = self._application_generate_entity.model_config
                 model = model_config.model
                 model_type_instance = model_config.provider_model_bundle.model_type_instance
                 model_type_instance = cast(LargeLanguageModel, model_type_instance)
@@ -210,7 +210,7 @@ class GenerateTaskPipeline:
             if isinstance(event, QueueMessageEndEvent):
                 self._task_state.llm_result = event.llm_result
             else:
-                model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
+                model_config = self._application_generate_entity.model_config
                 model = model_config.model
                 model_type_instance = model_config.provider_model_bundle.model_type_instance
                 model_type_instance = cast(LargeLanguageModel, model_type_instance)
@@ -569,7 +569,7 @@ class GenerateTaskPipeline:
         :return:
         """
         prompts = []
-        if self._application_generate_entity.app_orchestration_config_entity.model_config.mode == 'chat':
+        if self._application_generate_entity.model_config.mode == 'chat':
            for prompt_message in prompt_messages:
                 if prompt_message.role == PromptMessageRole.USER:
                     role = 'user'
@@ -638,13 +638,13 @@ class GenerateTaskPipeline:
         Init output moderation.
         :return:
         """
-        app_orchestration_config_entity = self._application_generate_entity.app_orchestration_config_entity
-        sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance
+        app_config = self._application_generate_entity.app_config
+        sensitive_word_avoidance = app_config.sensitive_word_avoidance
 
         if sensitive_word_avoidance:
             return OutputModeration(
-                tenant_id=self._application_generate_entity.tenant_id,
-                app_id=self._application_generate_entity.app_id,
+                tenant_id=app_config.tenant_id,
+                app_id=app_config.app_id,
                 rule=ModerationRule(
                     type=sensitive_word_avoidance.type,
                     config=sensitive_word_avoidance.config
diff --git a/api/core/app/validators/external_data_fetch.py b/api/core/app/validators/external_data_fetch.py
deleted file mode 100644
index 5910aa17e7..0000000000
--- a/api/core/app/validators/external_data_fetch.py
+++ /dev/null
@@ -1,39 +0,0 @@
-
-from core.external_data_tool.factory import ExternalDataToolFactory
-
-
-class ExternalDataFetchValidator:
-    @classmethod
-    def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
-        """
-        Validate and set defaults for external data fetch feature
-
-        :param tenant_id: workspace id
-        :param config: app model config args
-        """
-        if not config.get("external_data_tools"):
-            config["external_data_tools"] = []
-
-        if not isinstance(config["external_data_tools"], list):
-            raise ValueError("external_data_tools must be of list type")
-
-        for tool in config["external_data_tools"]:
-            if "enabled" not in tool or not tool["enabled"]:
-                tool["enabled"] = False
-
-            if not tool["enabled"]:
-                continue
-
-            if "type" not in tool or not tool["type"]:
-                raise ValueError("external_data_tools[].type is required")
-
-            typ = tool["type"]
-            config = tool["config"]
-
-            ExternalDataToolFactory.validate_config(
-                name=typ,
-                tenant_id=tenant_id,
-                config=config
-            )
-
-        return config, ["external_data_tools"]
diff --git a/api/core/app/validators/user_input_form.py b/api/core/app/validators/user_input_form.py
deleted file mode 100644
index 249d6745ae..0000000000
--- a/api/core/app/validators/user_input_form.py
+++ /dev/null
@@ -1,61 +0,0 @@
-import re
-
-
-class UserInputFormValidator:
-    @classmethod
-    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
-        """
-        Validate and set defaults for user input form
-
-        :param config: app model config args
-        """
-        if not config.get("user_input_form"):
-            config["user_input_form"] = []
-
-        if not isinstance(config["user_input_form"], list):
-            raise ValueError("user_input_form must be a list of objects")
-
-        variables = []
-        for item in config["user_input_form"]:
-            key = list(item.keys())[0]
-            if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]:
-                raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")
-
-            form_item = item[key]
-            if 'label' not in form_item:
-                raise ValueError("label is required in user_input_form")
-
-            if not isinstance(form_item["label"], str):
-                raise ValueError("label in user_input_form must be of string type")
-
-            if 'variable' not in form_item:
-                raise ValueError("variable is required in user_input_form")
-
-            if not isinstance(form_item["variable"], str):
-                raise ValueError("variable in user_input_form must be of string type")
-
-            pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
-            if pattern.match(form_item["variable"]) is None:
" - "and cannot start with a number") - - variables.append(form_item["variable"]) - - if 'required' not in form_item or not form_item["required"]: - form_item["required"] = False - - if not isinstance(form_item["required"], bool): - raise ValueError("required in user_input_form must be of boolean type") - - if key == "select": - if 'options' not in form_item or not form_item["options"]: - form_item["options"] = [] - - if not isinstance(form_item["options"], list): - raise ValueError("options in user_input_form must be a list of strings") - - if "default" in form_item and form_item['default'] \ - and form_item["default"] not in form_item["options"]: - raise ValueError("default value in user_input_form must be in the options list") - - return config, ["user_input_form"] diff --git a/api/core/app/workflow/config_validator.py b/api/core/app/workflow/config_validator.py deleted file mode 100644 index e8381146a7..0000000000 --- a/api/core/app/workflow/config_validator.py +++ /dev/null @@ -1,39 +0,0 @@ -from core.app.validators.file_upload import FileUploadValidator -from core.app.validators.moderation import ModerationValidator -from core.app.validators.text_to_speech import TextToSpeechValidator - - -class WorkflowAppConfigValidator: - @classmethod - def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: - """ - Validate for workflow app model config - - :param tenant_id: tenant id - :param config: app model config args - :param only_structure_validate: only validate the structure of the config - """ - related_config_keys = [] - - # file upload validation - config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # text_to_speech - config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # moderation validation - config, current_related_config_keys = ModerationValidator.validate_and_set_defaults( - tenant_id=tenant_id, - config=config, - only_structure_validate=only_structure_validate - ) - related_config_keys.extend(current_related_config_keys) - - related_config_keys = list(set(related_config_keys)) - - # Filter out extra parameters - filtered_config = {key: config.get(key) for key in related_config_keys} - - return filtered_config diff --git a/api/core/callback_handler/agent_loop_gather_callback_handler.py b/api/core/callback_handler/agent_loop_gather_callback_handler.py deleted file mode 100644 index 8a340a8b81..0000000000 --- a/api/core/callback_handler/agent_loop_gather_callback_handler.py +++ /dev/null @@ -1,262 +0,0 @@ -import json -import logging -import time -from typing import Any, Optional, Union, cast - -from langchain.agents import openai_functions_agent, openai_functions_multi_agent -from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import AgentAction, AgentFinish, BaseMessage, LLMResult - -from core.app.app_queue_manager import AppQueueManager, PublishFrom -from core.callback_handler.entity.agent_loop import AgentLoop -from core.entities.application_entities import ModelConfigEntity -from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult -from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from extensions.ext_database import db 
-from models.model import Message, MessageAgentThought, MessageChain - - -class AgentLoopGatherCallbackHandler(BaseCallbackHandler): - """Callback Handler that prints to std out.""" - raise_error: bool = True - - def __init__(self, model_config: ModelConfigEntity, - queue_manager: AppQueueManager, - message: Message, - message_chain: MessageChain) -> None: - """Initialize callback handler.""" - self.model_config = model_config - self.queue_manager = queue_manager - self.message = message - self.message_chain = message_chain - model_type_instance = self.model_config.provider_model_bundle.model_type_instance - self.model_type_instance = cast(LargeLanguageModel, model_type_instance) - self._agent_loops = [] - self._current_loop = None - self._message_agent_thought = None - - @property - def agent_loops(self) -> list[AgentLoop]: - return self._agent_loops - - def clear_agent_loops(self) -> None: - self._agent_loops = [] - self._current_loop = None - self._message_agent_thought = None - - @property - def always_verbose(self) -> bool: - """Whether to call verbose callbacks even if verbose is False.""" - return True - - @property - def ignore_chain(self) -> bool: - """Whether to ignore chain callbacks.""" - return True - - def on_llm_before_invoke(self, prompt_messages: list[PromptMessage]) -> None: - if not self._current_loop: - # Agent start with a LLM query - self._current_loop = AgentLoop( - position=len(self._agent_loops) + 1, - prompt="\n".join([prompt_message.content for prompt_message in prompt_messages]), - status='llm_started', - started_at=time.perf_counter() - ) - - def on_llm_after_invoke(self, result: RuntimeLLMResult) -> None: - if self._current_loop and self._current_loop.status == 'llm_started': - self._current_loop.status = 'llm_end' - if result.usage: - self._current_loop.prompt_tokens = result.usage.prompt_tokens - else: - self._current_loop.prompt_tokens = self.model_type_instance.get_num_tokens( - model=self.model_config.model, - credentials=self.model_config.credentials, - prompt_messages=[UserPromptMessage(content=self._current_loop.prompt)] - ) - - completion_message = result.message - if completion_message.tool_calls: - self._current_loop.completion \ - = json.dumps({'function_call': completion_message.tool_calls}) - else: - self._current_loop.completion = completion_message.content - - if result.usage: - self._current_loop.completion_tokens = result.usage.completion_tokens - else: - self._current_loop.completion_tokens = self.model_type_instance.get_num_tokens( - model=self.model_config.model, - credentials=self.model_config.credentials, - prompt_messages=[AssistantPromptMessage(content=self._current_loop.completion)] - ) - - def on_chat_model_start( - self, - serialized: dict[str, Any], - messages: list[list[BaseMessage]], - **kwargs: Any - ) -> Any: - pass - - def on_llm_start( - self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any - ) -> None: - pass - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Do nothing.""" - pass - - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - logging.debug("Agent on_llm_error: %s", error) - self._agent_loops = [] - self._current_loop = None - self._message_agent_thought = None - - def on_tool_start( - self, - serialized: dict[str, Any], - input_str: str, - **kwargs: Any, - ) -> None: - """Do nothing.""" - # kwargs={'color': 'green', 'llm_prefix': 'Thought:', 'observation_prefix': 'Observation: '} - # input_str='action-input' - # 
-        # serialized={'description': 'A search engine. Useful for when you need to answer questions about current events. Input should be a search query.', 'name': 'Search'}
-        pass
-
-    def on_agent_action(
-            self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
-    ) -> Any:
-        """Run on agent action."""
-        tool = action.tool
-        tool_input = json.dumps({"query": action.tool_input}
-                                if isinstance(action.tool_input, str) else action.tool_input)
-        completion = None
-        if isinstance(action, openai_functions_agent.base._FunctionsAgentAction) \
-                or isinstance(action, openai_functions_multi_agent.base._FunctionsAgentAction):
-            thought = action.log.strip()
-            completion = json.dumps({'function_call': action.message_log[0].additional_kwargs['function_call']})
-        else:
-            action_name_position = action.log.index("Action:") if action.log else -1
-            thought = action.log[:action_name_position].strip() if action.log else ''
-
-        if self._current_loop and self._current_loop.status == 'llm_end':
-            self._current_loop.status = 'agent_action'
-            self._current_loop.thought = thought
-            self._current_loop.tool_name = tool
-            self._current_loop.tool_input = tool_input
-            if completion is not None:
-                self._current_loop.completion = completion
-
-            self._message_agent_thought = self._init_agent_thought()
-
-    def on_tool_end(
-            self,
-            output: str,
-            color: Optional[str] = None,
-            observation_prefix: Optional[str] = None,
-            llm_prefix: Optional[str] = None,
-            **kwargs: Any,
-    ) -> None:
-        """If not the final action, print out observation."""
-        # kwargs={'name': 'Search'}
-        # llm_prefix='Thought:'
-        # observation_prefix='Observation: '
-        # output='53 years'
-
-        if self._current_loop and self._current_loop.status == 'agent_action' and output and output != 'None':
-            self._current_loop.status = 'tool_end'
-            self._current_loop.tool_output = output
-            self._current_loop.completed = True
-            self._current_loop.completed_at = time.perf_counter()
-            self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
-
-            self._complete_agent_thought(self._message_agent_thought)
-
-            self._agent_loops.append(self._current_loop)
-            self._current_loop = None
-            self._message_agent_thought = None
-
-    def on_tool_error(
-            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        logging.debug("Agent on_tool_error: %s", error)
-        self._agent_loops = []
-        self._current_loop = None
-        self._message_agent_thought = None
-
-    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
-        """Run on agent end."""
-        # Final Answer
-        if self._current_loop and (self._current_loop.status == 'llm_end' or self._current_loop.status == 'agent_action'):
-            self._current_loop.status = 'agent_finish'
-            self._current_loop.completed = True
-            self._current_loop.completed_at = time.perf_counter()
-            self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
-            self._current_loop.thought = '[DONE]'
-            self._message_agent_thought = self._init_agent_thought()
-
-            self._complete_agent_thought(self._message_agent_thought)
-
-            self._agent_loops.append(self._current_loop)
-            self._current_loop = None
-            self._message_agent_thought = None
-        elif not self._current_loop and self._agent_loops:
-            self._agent_loops[-1].status = 'agent_finish'
-
-    def _init_agent_thought(self) -> MessageAgentThought:
-        message_agent_thought = MessageAgentThought(
-            message_id=self.message.id,
-            message_chain_id=self.message_chain.id,
-            position=self._current_loop.position,
-            thought=self._current_loop.thought,
-            tool=self._current_loop.tool_name,
-            tool_input=self._current_loop.tool_input,
-            message=self._current_loop.prompt,
-            message_price_unit=0,
-            answer=self._current_loop.completion,
-            answer_price_unit=0,
-            created_by_role=('account' if self.message.from_source == 'console' else 'end_user'),
-            created_by=(self.message.from_account_id
-                        if self.message.from_source == 'console' else self.message.from_end_user_id)
-        )
-
-        db.session.add(message_agent_thought)
-        db.session.commit()
-
-        self.queue_manager.publish_agent_thought(message_agent_thought, PublishFrom.APPLICATION_MANAGER)
-
-        return message_agent_thought
-
-    def _complete_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
-        loop_message_tokens = self._current_loop.prompt_tokens
-        loop_answer_tokens = self._current_loop.completion_tokens
-
-        # transform usage
-        llm_usage = self.model_type_instance._calc_response_usage(
-            self.model_config.model,
-            self.model_config.credentials,
-            loop_message_tokens,
-            loop_answer_tokens
-        )
-
-        message_agent_thought.observation = self._current_loop.tool_output
-        message_agent_thought.tool_process_data = ''  # currently not support
-        message_agent_thought.message_token = loop_message_tokens
-        message_agent_thought.message_unit_price = llm_usage.prompt_unit_price
-        message_agent_thought.message_price_unit = llm_usage.prompt_price_unit
-        message_agent_thought.answer_token = loop_answer_tokens
-        message_agent_thought.answer_unit_price = llm_usage.completion_unit_price
-        message_agent_thought.answer_price_unit = llm_usage.completion_price_unit
-        message_agent_thought.latency = self._current_loop.latency
-        message_agent_thought.tokens = self._current_loop.prompt_tokens + self._current_loop.completion_tokens
-        message_agent_thought.total_price = llm_usage.total_price
-        message_agent_thought.currency = llm_usage.currency
-        db.session.commit()
diff --git a/api/core/callback_handler/entity/agent_loop.py b/api/core/callback_handler/entity/agent_loop.py
deleted file mode 100644
index 56634bb19e..0000000000
--- a/api/core/callback_handler/entity/agent_loop.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from pydantic import BaseModel
-
-
-class AgentLoop(BaseModel):
-    position: int = 1
-
-    thought: str = None
-    tool_name: str = None
-    tool_input: str = None
-    tool_output: str = None
-
-    prompt: str = None
-    prompt_tokens: int = 0
-    completion: str = None
-    completion_tokens: int = 0
-
-    latency: float = None
-
-    status: str = 'llm_started'
-    completed: bool = False
-
-    started_at: float = None
-    completed_at: float = None
\ No newline at end of file
diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py
index e49a09d4c4..ca781a55bc 100644
--- a/api/core/callback_handler/index_tool_callback_handler.py
+++ b/api/core/callback_handler/index_tool_callback_handler.py
@@ -1,6 +1,6 @@
 from core.app.app_queue_manager import AppQueueManager, PublishFrom
-from core.entities.application_entities import InvokeFrom
+from core.app.entities.app_invoke_entities import InvokeFrom
 from core.rag.models.document import Document
 from extensions.ext_database import db
 from models.dataset import DatasetQuery, DocumentSegment
diff --git a/api/core/external_data_tool/external_data_fetch.py b/api/core/external_data_tool/external_data_fetch.py
index 64c7d1e859..8601cb34e7 100644
--- a/api/core/external_data_tool/external_data_fetch.py
+++ b/api/core/external_data_tool/external_data_fetch.py
@@ -5,7 +5,7 @@
 from typing import Optional
 
 from flask import Flask, current_app
 
-from core.entities.application_entities import ExternalDataVariableEntity
+from core.app.app_config.entities import ExternalDataVariableEntity
 from core.external_data_tool.factory import ExternalDataToolFactory
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/file/file_obj.py b/api/core/file/file_obj.py
index 435074f743..bd896719c2 100644
--- a/api/core/file/file_obj.py
+++ b/api/core/file/file_obj.py
@@ -3,6 +3,7 @@ from typing import Optional
 
 from pydantic import BaseModel
 
+from core.app.app_config.entities import FileUploadEntity
 from core.file.upload_file_parser import UploadFileParser
 from core.model_runtime.entities.message_entities import ImagePromptMessageContent
 from extensions.ext_database import db
@@ -50,7 +51,7 @@ class FileObj(BaseModel):
     transfer_method: FileTransferMethod
     url: Optional[str]
     upload_file_id: Optional[str]
-    file_config: dict
+    file_upload_entity: FileUploadEntity
 
     @property
     def data(self) -> Optional[str]:
@@ -63,7 +64,7 @@ class FileObj(BaseModel):
     @property
     def prompt_message_content(self) -> ImagePromptMessageContent:
        if self.type == FileType.IMAGE:
-            image_config = self.file_config.get('image')
+            image_config = self.file_upload_entity.image_config
 
             return ImagePromptMessageContent(
                 data=self.data,
diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py
index c132073578..9d122c4120 100644
--- a/api/core/file/message_file_parser.py
+++ b/api/core/file/message_file_parser.py
@@ -1,11 +1,12 @@
-from typing import Optional, Union
+from typing import Union
 
 import requests
 
+from core.app.app_config.entities import FileUploadEntity
 from core.file.file_obj import FileBelongsTo, FileObj, FileTransferMethod, FileType
 from extensions.ext_database import db
 from models.account import Account
-from models.model import AppModelConfig, EndUser, MessageFile, UploadFile
+from models.model import EndUser, MessageFile, UploadFile
 from services.file_service import IMAGE_EXTENSIONS
 
 
@@ -15,18 +16,16 @@ class MessageFileParser:
         self.tenant_id = tenant_id
         self.app_id = app_id
 
-    def validate_and_transform_files_arg(self, files: list[dict], app_model_config: AppModelConfig,
+    def validate_and_transform_files_arg(self, files: list[dict], file_upload_entity: FileUploadEntity,
                                          user: Union[Account, EndUser]) -> list[FileObj]:
         """
         validate and transform files arg
 
         :param files:
-        :param app_model_config:
+        :param file_upload_entity:
         :param user:
         :return:
         """
-        file_upload_config = app_model_config.file_upload_dict
-
         for file in files:
             if not isinstance(file, dict):
                 raise ValueError('Invalid file format, must be dict')
@@ -45,17 +44,17 @@ class MessageFileParser:
                     raise ValueError('Missing file upload_file_id')
 
         # transform files to file objs
-        type_file_objs = self._to_file_objs(files, file_upload_config)
+        type_file_objs = self._to_file_objs(files, file_upload_entity)
 
         # validate files
         new_files = []
         for file_type, file_objs in type_file_objs.items():
             if file_type == FileType.IMAGE:
                 # parse and validate files
-                image_config = file_upload_config.get('image')
+                image_config = file_upload_entity.image_config
 
                 # check if image file feature is enabled
-                if not image_config['enabled']:
+                if not image_config:
                     continue
 
                 # Validate number of files
@@ -96,27 +95,27 @@ class MessageFileParser:
         # return all file objs
         return new_files
 
-    def transform_message_files(self, files: list[MessageFile], file_upload_config: Optional[dict]) -> list[FileObj]:
+    def transform_message_files(self, files: list[MessageFile], file_upload_entity: FileUploadEntity) -> list[FileObj]:
         """
         transform message files
 
         :param files:
-        :param file_upload_config:
+        :param file_upload_entity:
         :return:
         """
         # transform files to file objs
-        type_file_objs = self._to_file_objs(files, file_upload_config)
+        type_file_objs = self._to_file_objs(files, file_upload_entity)
 
         # return all file objs
         return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
 
     def _to_file_objs(self, files: list[Union[dict, MessageFile]],
-                      file_upload_config: dict) -> dict[FileType, list[FileObj]]:
+                      file_upload_entity: FileUploadEntity) -> dict[FileType, list[FileObj]]:
         """
         transform files to file objs
 
         :param files:
-        :param file_upload_config:
+        :param file_upload_entity:
         :return:
         """
         type_file_objs: dict[FileType, list[FileObj]] = {
@@ -133,7 +132,7 @@ class MessageFileParser:
                 if file.belongs_to == FileBelongsTo.ASSISTANT.value:
                     continue
 
-                file_obj = self._to_file_obj(file, file_upload_config)
+                file_obj = self._to_file_obj(file, file_upload_entity)
                 if file_obj.type not in type_file_objs:
                     continue
 
@@ -141,7 +140,7 @@ class MessageFileParser:
 
         return type_file_objs
 
-    def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_config: dict) -> FileObj:
+    def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_entity: FileUploadEntity) -> FileObj:
         """
         transform file to file obj
 
@@ -156,7 +155,7 @@ class MessageFileParser:
                 transfer_method=transfer_method,
                 url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
                 upload_file_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
-                file_config=file_upload_config
+                file_upload_entity=file_upload_entity
             )
         else:
             return FileObj(
@@ -166,7 +165,7 @@ class MessageFileParser:
                 transfer_method=FileTransferMethod.value_of(file.transfer_method),
                 url=file.url,
                 upload_file_id=file.upload_file_id or None,
-                file_config=file_upload_config
+                file_upload_entity=file_upload_entity
             )
 
     def _check_image_remote_url(self, url):
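After this change, callers hand MessageFileParser a typed FileUploadEntity instead of an AppModelConfig, and a missing image_config now disables the image path instead of failing on a missing 'enabled' key. A hedged call-shape sketch — FileUploadEntity's constructor arguments beyond image_config and the exact image_config shape are assumptions for illustration, and the ids are placeholders:

from core.app.app_config.entities import FileUploadEntity
from core.file.message_file_parser import MessageFileParser
from models.account import Account

def parse_uploaded_images(user: Account) -> list:
    parser = MessageFileParser(tenant_id='tenant-id', app_id='app-id')  # placeholder ids
    # image_config is the attribute the parser actually reads; a falsy value
    # simply skips the image feature instead of raising.
    file_upload_entity = FileUploadEntity(
        image_config={'number_limits': 3, 'transfer_methods': ['remote_url', 'local_file']}
    )
    files = [{'type': 'image', 'transfer_method': 'remote_url', 'url': 'https://example.com/a.png'}]
    return parser.validate_and_transform_files_arg(files, file_upload_entity, user)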
diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py
index 86d6b498da..bff9b9cf1f 100644
--- a/api/core/helper/moderation.py
+++ b/api/core/helper/moderation.py
@@ -1,7 +1,7 @@
 import logging
 import random
 
-from core.entities.application_entities import ModelConfigEntity
+from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
 from core.model_runtime.errors.invoke import InvokeBadRequestError
 from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel
 from extensions.ext_hosting_provider import hosting_configuration
@@ -10,7 +10,7 @@ from models.provider import ProviderType
 
 logger = logging.getLogger(__name__)
 
-def check_moderation(model_config: ModelConfigEntity, text: str) -> bool:
+def check_moderation(model_config: EasyUIBasedModelConfigEntity, text: str) -> bool:
     moderation_config = hosting_configuration.moderation_config
     if (moderation_config and moderation_config.enabled is True
             and 'openai' in hosting_configuration.provider_map
diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py
index 00813faef7..4fe150e983 100644
--- a/api/core/memory/token_buffer_memory.py
+++ b/api/core/memory/token_buffer_memory.py
@@ -1,3 +1,5 @@
+from core.app.app_config.entities import FileUploadEntity
+from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.file.message_file_parser import MessageFileParser
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.message_entities import (
@@ -43,12 +45,18 @@ class TokenBufferMemory:
         for message in messages:
             files = message.message_files
             if files:
-                file_objs = message_file_parser.transform_message_files(
-                    files,
-                    message.app_model_config.file_upload_dict
-                    if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
-                    else message.workflow_run.workflow.features_dict.get('file_upload', {})
-                )
+                if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
+                    file_upload_entity = FileUploadConfigManager.convert(message.app_model_config.to_dict())
+                else:
+                    file_upload_entity = FileUploadConfigManager.convert(message.workflow_run.workflow.features_dict)
+
+                if file_upload_entity:
+                    file_objs = message_file_parser.transform_message_files(
+                        files,
+                        file_upload_entity
+                    )
+                else:
+                    file_objs = []
 
                 if not file_objs:
                     prompt_messages.append(UserPromptMessage(content=message.query))
diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py
index 2129c58d8d..8fbc0c2d50 100644
--- a/api/core/moderation/input_moderation.py
+++ b/api/core/moderation/input_moderation.py
@@ -1,6 +1,6 @@
 import logging
 
-from core.entities.application_entities import AppOrchestrationConfigEntity
+from core.app.app_config.entities import AppConfig
 from core.moderation.base import ModerationAction, ModerationException
 from core.moderation.factory import ModerationFactory
 
@@ -10,22 +10,22 @@ logger = logging.getLogger(__name__)
 class InputModeration:
     def check(self, app_id: str,
               tenant_id: str,
-              app_orchestration_config_entity: AppOrchestrationConfigEntity,
+              app_config: AppConfig,
               inputs: dict,
               query: str) -> tuple[bool, dict, str]:
         """
         Process sensitive_word_avoidance.
         :param app_id: app id
         :param tenant_id: tenant id
-        :param app_orchestration_config_entity: app orchestration config entity
+        :param app_config: app config
         :param inputs: inputs
         :param query: query
         :return:
         """
-        if not app_orchestration_config_entity.sensitive_word_avoidance:
+        if not app_config.sensitive_word_avoidance:
             return False, inputs, query
 
-        sensitive_word_avoidance_config = app_orchestration_config_entity.sensitive_word_avoidance
+        sensitive_word_avoidance_config = app_config.sensitive_word_avoidance
         moderation_type = sensitive_word_avoidance_config.type
 
         moderation_factory = ModerationFactory(
diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py
index 6d0a1d31f5..129c2a4cd2 100644
--- a/api/core/prompt/advanced_prompt_transform.py
+++ b/api/core/prompt/advanced_prompt_transform.py
@@ -1,10 +1,7 @@
 from typing import Optional
 
-from core.entities.application_entities import (
-    AdvancedCompletionPromptTemplateEntity,
-    ModelConfigEntity,
-    PromptTemplateEntity,
-)
+from core.app.app_config.entities import PromptTemplateEntity, AdvancedCompletionPromptTemplateEntity
+from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
 from core.file.file_obj import FileObj
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import (
@@ -31,7 +28,7 @@ class AdvancedPromptTransform(PromptTransform):
                             files: list[FileObj],
                             context: Optional[str],
                             memory: Optional[TokenBufferMemory],
-                            model_config: ModelConfigEntity) -> list[PromptMessage]:
+                            model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]:
         prompt_messages = []
 
         model_mode = ModelMode.value_of(model_config.mode)
@@ -65,7 +62,7 @@ class AdvancedPromptTransform(PromptTransform):
                                            files: list[FileObj],
                                            context: Optional[str],
                                            memory: Optional[TokenBufferMemory],
-                                           model_config: ModelConfigEntity) -> list[PromptMessage]:
+                                           model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]:
         """
         Get completion model prompt messages.
         """
@@ -113,7 +110,7 @@ class AdvancedPromptTransform(PromptTransform):
                                      files: list[FileObj],
                                      context: Optional[str],
                                      memory: Optional[TokenBufferMemory],
-                                     model_config: ModelConfigEntity) -> list[PromptMessage]:
+                                     model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]:
         """
         Get chat model prompt messages.
         """
@@ -202,7 +199,7 @@ class AdvancedPromptTransform(PromptTransform):
                                role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity,
                                prompt_template: PromptTemplateParser,
                                prompt_inputs: dict,
-                               model_config: ModelConfigEntity) -> dict:
+                               model_config: EasyUIBasedModelConfigEntity) -> dict:
         if '#histories#' in prompt_template.variable_keys:
             if memory:
                 inputs = {'#histories#': '', **prompt_inputs}
diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py
index 9c554140b7..7fe8128a49 100644
--- a/api/core/prompt/prompt_transform.py
+++ b/api/core/prompt/prompt_transform.py
@@ -1,6 +1,6 @@
 from typing import Optional, cast
 
-from core.entities.application_entities import ModelConfigEntity
+from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import PromptMessage
 from core.model_runtime.entities.model_entities import ModelPropertyKey
@@ -10,14 +10,14 @@ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 class PromptTransform:
     def _append_chat_histories(self, memory: TokenBufferMemory,
                                prompt_messages: list[PromptMessage],
-                               model_config: ModelConfigEntity) -> list[PromptMessage]:
+                               model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]:
         rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
         histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
         prompt_messages.extend(histories)
 
         return prompt_messages
 
-    def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: ModelConfigEntity) -> int:
+    def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: EasyUIBasedModelConfigEntity) -> int:
         rest_tokens = 2000
 
         model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py
index af7b695bb3..faf1f888e2 100644
--- a/api/core/prompt/simple_prompt_transform.py
+++ b/api/core/prompt/simple_prompt_transform.py
@@ -3,10 +3,8 @@ import json
 import os
 from typing import Optional
 
-from core.entities.application_entities import (
-    ModelConfigEntity,
-    PromptTemplateEntity,
-)
+from core.app.app_config.entities import PromptTemplateEntity
+from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
 from core.file.file_obj import FileObj
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import (
@@ -54,7 +52,7 @@ class SimplePromptTransform(PromptTransform):
                     files: list[FileObj],
                     context: Optional[str],
                     memory: Optional[TokenBufferMemory],
-                    model_config: ModelConfigEntity) -> \
+                    model_config: EasyUIBasedModelConfigEntity) -> \
            tuple[list[PromptMessage], Optional[list[str]]]:
        model_mode = ModelMode.value_of(model_config.mode)
        if model_mode == ModelMode.CHAT:
@@ -83,7 +81,7 @@ class SimplePromptTransform(PromptTransform):
         return prompt_messages, stops
 
     def get_prompt_str_and_rules(self, app_mode: AppMode,
-                                 model_config: ModelConfigEntity,
+                                 model_config: EasyUIBasedModelConfigEntity,
                                  pre_prompt: str,
                                  inputs: dict,
                                  query: Optional[str] = None,
@@ -164,7 +162,7 @@ class SimplePromptTransform(PromptTransform):
                                           context: Optional[str],
                                           files: list[FileObj],
                                           memory: Optional[TokenBufferMemory],
-                                          model_config: ModelConfigEntity) \
+                                          model_config: EasyUIBasedModelConfigEntity) \
             -> tuple[list[PromptMessage], Optional[list[str]]]:
         prompt_messages = []
 
@@ -202,7 +200,7 @@ class SimplePromptTransform(PromptTransform):
                                                 context: Optional[str],
                                                 files: list[FileObj],
                                                 memory: Optional[TokenBufferMemory],
-                                                model_config: ModelConfigEntity) \
+                                                model_config: EasyUIBasedModelConfigEntity) \
             -> tuple[list[PromptMessage], Optional[list[str]]]:
         # get prompt
         prompt, prompt_rules = self.get_prompt_str_and_rules(
diff --git a/api/core/rag/retrieval/agent/agent_llm_callback.py b/api/core/rag/retrieval/agent/agent_llm_callback.py
deleted file mode 100644
index 5ec549de8e..0000000000
--- a/api/core/rag/retrieval/agent/agent_llm_callback.py
+++ /dev/null
@@ -1,101 +0,0 @@
-import logging
-from typing import Optional
-
-from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
-from core.model_runtime.callbacks.base_callback import Callback
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
-from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
-from core.model_runtime.model_providers.__base.ai_model import AIModel
-
-logger = logging.getLogger(__name__)
-
-
-class AgentLLMCallback(Callback):
-
-    def __init__(self, agent_callback: AgentLoopGatherCallbackHandler) -> None:
-        self.agent_callback = agent_callback
-
-    def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
-                         prompt_messages: list[PromptMessage], model_parameters: dict,
-                         tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
-                         stream: bool = True, user: Optional[str] = None) -> None:
-        """
-        Before invoke callback
-
-        :param llm_instance: LLM instance
-        :param model: model name
-        :param credentials: model credentials
-        :param prompt_messages: prompt messages
-        :param model_parameters: model parameters
-        :param tools: tools for tool calling
-        :param stop: stop words
-        :param stream: is stream response
-        :param user: unique user id
-        """
-        self.agent_callback.on_llm_before_invoke(
-            prompt_messages=prompt_messages
-        )
-
-    def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
-                     prompt_messages: list[PromptMessage], model_parameters: dict,
-                     tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
-                     stream: bool = True, user: Optional[str] = None):
-        """
-        On new chunk callback
-
-        :param llm_instance: LLM instance
-        :param chunk: chunk
-        :param model: model name
-        :param credentials: model credentials
-        :param prompt_messages: prompt messages
-        :param model_parameters: model parameters
-        :param tools: tools for tool calling
-        :param stop: stop words
-        :param stream: is stream response
-        :param user: unique user id
-        """
-        pass
-
-    def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
-                        prompt_messages: list[PromptMessage], model_parameters: dict,
-                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
-                        stream: bool = True, user: Optional[str] = None) -> None:
-        """
-        After invoke callback
-
-        :param llm_instance: LLM instance
-        :param result: result
-        :param model: model name
-        :param credentials: model credentials
-        :param prompt_messages: prompt messages
-        :param model_parameters: model parameters
-        :param tools: tools for tool calling
-        :param stop: stop words
-        :param stream: is stream response
-        :param user: unique user id
-        """
-        self.agent_callback.on_llm_after_invoke(
-            result=result
-        )
-
-    def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
-                        prompt_messages: list[PromptMessage], model_parameters: dict,
-                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
-                        stream: bool = True, user: Optional[str] = None) -> None:
-        """
-        Invoke error callback
-
-        :param llm_instance: LLM instance
-        :param ex: exception
-        :param model: model name
-        :param credentials: model credentials
-        :param prompt_messages: prompt messages
-        :param model_parameters: model parameters
-        :param tools: tools for tool calling
-        :param stop: stop words
-        :param stream: is stream response
-        :param user: unique user id
-        """
-        self.agent_callback.on_llm_error(
-            error=ex
-        )
diff --git a/api/core/rag/retrieval/agent/llm_chain.py b/api/core/rag/retrieval/agent/llm_chain.py
index 087b7bfa2c..9b115bc696 100644
--- a/api/core/rag/retrieval/agent/llm_chain.py
+++ b/api/core/rag/retrieval/agent/llm_chain.py
@@ -5,19 +5,17 @@
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.schema import Generation, LLMResult
 from langchain.schema.language_model import BaseLanguageModel
 
-from core.entities.application_entities import ModelConfigEntity
+from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
 from core.entities.message_entities import lc_messages_to_prompt_messages
 from core.model_manager import ModelInstance
-from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback
 from core.rag.retrieval.agent.fake_llm import FakeLLM
 
 
 class LLMChain(LCLLMChain):
-    model_config: ModelConfigEntity
+    model_config: EasyUIBasedModelConfigEntity
     """The language model instance to use."""
     llm: BaseLanguageModel = FakeLLM(response="")
     parameters: dict[str, Any] = {}
-    agent_llm_callback: Optional[AgentLLMCallback] = None
 
     def generate(
         self,
@@ -38,7 +36,6 @@ class LLMChain(LCLLMChain):
             prompt_messages=prompt_messages,
             stream=False,
             stop=stop,
-            callbacks=[self.agent_llm_callback] if self.agent_llm_callback else None,
             model_parameters=self.parameters
         )
diff --git a/api/core/rag/retrieval/agent/multi_dataset_router_agent.py b/api/core/rag/retrieval/agent/multi_dataset_router_agent.py
index 41a0c54041..84e2b0228f 100644
--- a/api/core/rag/retrieval/agent/multi_dataset_router_agent.py
+++ b/api/core/rag/retrieval/agent/multi_dataset_router_agent.py
@@ -10,7 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, AIMessage, SystemMessage
 from langchain.tools import BaseTool
 from pydantic import root_validator
 
-from core.entities.application_entities import ModelConfigEntity
+from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
 from core.entities.message_entities import lc_messages_to_prompt_messages
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.message_entities import PromptMessageTool
@@ -21,7 +21,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
     """
     A Multi Dataset Retrieve Agent driven by Router.
     """
-    model_config: ModelConfigEntity
+    model_config: EasyUIBasedModelConfigEntity
 
     class Config:
         """Configuration for this pydantic object."""
@@ -156,7 +156,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
     @classmethod
     def from_llm_and_tools(
         cls,
-        model_config: ModelConfigEntity,
+        model_config: EasyUIBasedModelConfigEntity,
         tools: Sequence[BaseTool],
         callback_manager: Optional[BaseCallbackManager] = None,
         extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
diff --git a/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py b/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py
index 4d7d33038b..700bf0c293 100644
--- a/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py
+++ b/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py
@@ -12,7 +12,7 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
 from langchain.schema import AgentAction, AgentFinish, OutputParserException
 from langchain.tools import BaseTool
 
-from core.entities.application_entities import ModelConfigEntity
+from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
 from core.rag.retrieval.agent.llm_chain import LLMChain
 
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
@@ -206,7 +206,7 @@ Thought: {agent_scratchpad}
     @classmethod
     def from_llm_and_tools(
         cls,
-        model_config: ModelConfigEntity,
+        model_config: EasyUIBasedModelConfigEntity,
         tools: Sequence[BaseTool],
         callback_manager: Optional[BaseCallbackManager] = None,
         output_parser: Optional[AgentOutputParser] = None,
diff --git a/api/core/rag/retrieval/agent_based_dataset_executor.py b/api/core/rag/retrieval/agent_based_dataset_executor.py
index 7fabf71bed..749e603c5c 100644
--- a/api/core/rag/retrieval/agent_based_dataset_executor.py
+++ b/api/core/rag/retrieval/agent_based_dataset_executor.py
@@ -7,13 +7,12 @@ from langchain.callbacks.manager import Callbacks
 from langchain.tools import BaseTool
 from pydantic import BaseModel, Extra
 
+from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
 from core.entities.agent_entities import PlanningStrategy
-from core.entities.application_entities import ModelConfigEntity
 from core.entities.message_entities import prompt_messages_to_lc_messages
 from core.helper import moderation
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.errors.invoke import InvokeError
-from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback
 from core.rag.retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
 from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
 from core.rag.retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
@@ -23,15 +22,14 @@ from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
 
 class AgentConfiguration(BaseModel):
     strategy: PlanningStrategy
-    model_config: ModelConfigEntity
+    model_config: EasyUIBasedModelConfigEntity
     tools: list[BaseTool]
-    summary_model_config: Optional[ModelConfigEntity] = None
+    summary_model_config: Optional[EasyUIBasedModelConfigEntity] = None
     memory: Optional[TokenBufferMemory] = None
     callbacks: Callbacks = None
     max_iterations: int = 6
     max_execution_time: Optional[float] = None
     early_stopping_method: str = "generate"
-    agent_llm_callback: Optional[AgentLLMCallback] = None
     # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
 
     class Config:
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 21e16c4162..8f1221adc7 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -2,9 +2,10 @@ from typing import Optional, cast
 
 from langchain.tools import BaseTool
 
+from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.agent_entities import PlanningStrategy
-from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity
+from core.app.entities.app_invoke_entities import InvokeFrom, EasyUIBasedModelConfigEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -17,7 +18,7 @@ from models.dataset import Dataset
 
 class DatasetRetrieval:
     def retrieve(self, tenant_id: str,
-                 model_config: ModelConfigEntity,
+                 model_config: EasyUIBasedModelConfigEntity,
                  config: DatasetEntity,
                  query: str,
                  invoke_from: InvokeFrom,
diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py
index 629ed23613..80062e606a 100644
--- a/api/core/tools/tool/dataset_retriever_tool.py
+++ b/api/core/tools/tool/dataset_retriever_tool.py
@@ -2,8 +2,9 @@ from typing import Any
 
 from langchain.tools import BaseTool
 
+from core.app.app_config.entities import DatasetRetrieveConfigEntity
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
-from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom
+from core.app.entities.app_invoke_entities import InvokeFrom
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter
diff --git a/api/events/event_handlers/deduct_quota_when_messaeg_created.py b/api/events/event_handlers/deduct_quota_when_messaeg_created.py
index 8c335f201f..49eea603dc 100644
--- a/api/events/event_handlers/deduct_quota_when_messaeg_created.py
+++ b/api/events/event_handlers/deduct_quota_when_messaeg_created.py
@@ -1,4 +1,4 @@
-from core.entities.application_entities import ApplicationGenerateEntity
+from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity
 from core.entities.provider_entities import QuotaUnit
 from events.message_event import message_was_created
 from extensions.ext_database import db
@@ -8,9 +8,9 @@ from models.provider import Provider, ProviderType
 @message_was_created.connect
 def handle(sender, **kwargs):
     message = sender
-    application_generate_entity: ApplicationGenerateEntity = kwargs.get('application_generate_entity')
+    application_generate_entity: EasyUIBasedAppGenerateEntity = kwargs.get('application_generate_entity')
 
-    model_config = application_generate_entity.app_orchestration_config_entity.model_config
+    model_config = application_generate_entity.model_config
     provider_model_bundle = model_config.provider_model_bundle
     provider_configuration = provider_model_bundle.configuration
 
@@ -43,7 +43,7 @@
def handle(sender, **kwargs): if used_quota is not None: db.session.query(Provider).filter( - Provider.tenant_id == application_generate_entity.tenant_id, + Provider.tenant_id == application_generate_entity.app_config.tenant_id, Provider.provider_name == model_config.provider, Provider.provider_type == ProviderType.SYSTEM.value, Provider.quota_type == system_configuration.current_quota_type.value, diff --git a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py index 69b3a90e44..d49e560a67 100644 --- a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py +++ b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py @@ -1,6 +1,6 @@ from datetime import datetime -from core.entities.application_entities import ApplicationGenerateEntity +from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from events.message_event import message_was_created from extensions.ext_database import db from models.provider import Provider @@ -9,10 +9,10 @@ from models.provider import Provider @message_was_created.connect def handle(sender, **kwargs): message = sender - application_generate_entity: ApplicationGenerateEntity = kwargs.get('application_generate_entity') + application_generate_entity: EasyUIBasedAppGenerateEntity = kwargs.get('application_generate_entity') db.session.query(Provider).filter( - Provider.tenant_id == application_generate_entity.tenant_id, - Provider.provider_name == application_generate_entity.app_orchestration_config_entity.model_config.provider + Provider.tenant_id == application_generate_entity.app_config.tenant_id, + Provider.provider_name == application_generate_entity.model_config.provider ).update({'last_used': datetime.utcnow()}) db.session.commit() diff --git a/api/models/model.py b/api/models/model.py index e514ea729b..f8f9a0a3cd 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -105,6 +105,18 @@ class App(db.Model): tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() return tenant + @property + def is_agent(self) -> bool: + app_model_config = self.app_model_config + if not app_model_config: + return False + if not app_model_config.agent_mode: + return False + if self.app_model_config.agent_mode_dict.get('enabled', False) \ + and self.app_model_config.agent_mode_dict.get('strategy', '') in ['function_call', 'react']: + return True + return False + @property def deleted_tools(self) -> list: # get agent mode tools diff --git a/api/models/workflow.py b/api/models/workflow.py index ff4e944e29..f9c906b85c 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -129,7 +129,7 @@ class Workflow(db.Model): def features_dict(self): return self.features if not self.features else json.loads(self.features) - def user_input_form(self): + def user_input_form(self) -> list: # get start node from graph if not self.graph: return [] diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index f2caeb14ff..c84f6fbf45 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -1,6 +1,6 @@ -from core.app.agent_chat.config_validator import AgentChatAppConfigValidator -from core.app.chat.config_validator import ChatAppConfigValidator -from core.app.completion.config_validator import CompletionAppConfigValidator +from core.app.apps.agent_chat.app_config_manager import 
AgentChatAppConfigManager +from core.app.apps.chat.app_config_manager import ChatAppConfigManager +from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from models.model import AppMode @@ -9,10 +9,10 @@ class AppModelConfigService: @classmethod def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict: if app_mode == AppMode.CHAT: - return ChatAppConfigValidator.config_validate(tenant_id, config) + return ChatAppConfigManager.config_validate(tenant_id, config) elif app_mode == AppMode.AGENT_CHAT: - return AgentChatAppConfigValidator.config_validate(tenant_id, config) + return AgentChatAppConfigManager.config_validate(tenant_id, config) elif app_mode == AppMode.COMPLETION: - return CompletionAppConfigValidator.config_validate(tenant_id, config) + return CompletionAppConfigManager.config_validate(tenant_id, config) else: raise ValueError(f"Invalid app mode: {app_mode}") diff --git a/api/services/completion_service.py b/api/services/completion_service.py index 8a9639e521..453194feb1 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -4,9 +4,9 @@ from typing import Any, Union from sqlalchemy import and_ -from core.app.app_manager import AppManager -from core.app.validators.model_validator import ModelValidator -from core.entities.application_entities import InvokeFrom +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_manager import EasyUIBasedAppManager +from core.app.entities.app_invoke_entities import InvokeFrom from core.file.message_file_parser import MessageFileParser from extensions.ext_database import db from models.model import Account, App, AppMode, AppModelConfig, Conversation, EndUser, Message @@ -30,7 +30,7 @@ class CompletionService: auto_generate_name = args['auto_generate_name'] \ if 'auto_generate_name' in args else True - if app_model.mode != 'completion': + if app_model.mode != AppMode.COMPLETION.value: if not query: raise ValueError('query is required') @@ -43,6 +43,7 @@ class CompletionService: conversation_id = args['conversation_id'] if 'conversation_id' in args else None conversation = None + app_model_config_dict = None if conversation_id: conversation_filter = [ Conversation.id == args['conversation_id'], @@ -63,42 +64,13 @@ class CompletionService: if conversation.status != 'normal': raise ConversationCompletedError() - if not conversation.override_model_configs: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation.app_model_config_id, - AppModelConfig.app_id == app_model.id - ).first() + app_model_config = db.session.query(AppModelConfig).filter( + AppModelConfig.id == conversation.app_model_config_id, + AppModelConfig.app_id == app_model.id + ).first() - if not app_model_config: - raise AppModelConfigBrokenError() - else: - conversation_override_model_configs = json.loads(conversation.override_model_configs) - - app_model_config = AppModelConfig( - id=conversation.app_model_config_id, - app_id=app_model.id, - ) - - app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs) - - if is_model_config_override: - # build new app model config - if 'model' not in args['model_config']: - raise ValueError('model_config.model is required') - - if 'completion_params' not in args['model_config']['model']: - raise ValueError('model_config.model.completion_params is required') - - completion_params = 
ModelValidator.validate_model_completion_params( - cp=args['model_config']['model']['completion_params'] - ) - - app_model_config_model = app_model_config.model_dict - app_model_config_model['completion_params'] = completion_params - app_model_config.retriever_resource = json.dumps({'enabled': True}) - - app_model_config = app_model_config.copy() - app_model_config.model = json.dumps(app_model_config_model) + if not app_model_config: + raise AppModelConfigBrokenError() else: if app_model.app_model_config_id is None: raise AppModelConfigBrokenError() @@ -113,37 +85,29 @@ class CompletionService: raise Exception("Only account can override model config") # validate config - model_config = AppModelConfigService.validate_configuration( + app_model_config_dict = AppModelConfigService.validate_configuration( tenant_id=app_model.tenant_id, config=args['model_config'], app_mode=AppMode.value_of(app_model.mode) ) - app_model_config = AppModelConfig( - id=app_model_config.id, - app_id=app_model.id, - ) - - app_model_config = app_model_config.from_model_config_dict(model_config) - - # clean input by app_model_config form rules - inputs = cls.get_cleaned_inputs(inputs, app_model_config) - # parse files message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - app_model_config, - user - ) + file_upload_entity = FileUploadConfigManager.convert(app_model_config_dict or app_model_config.to_dict()) + if file_upload_entity: + file_objs = message_file_parser.validate_and_transform_files_arg( + files, + file_upload_entity, + user + ) + else: + file_objs = [] - application_manager = AppManager() + application_manager = EasyUIBasedAppManager() return application_manager.generate( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - app_model_config_id=app_model_config.id, - app_model_config_dict=app_model_config.to_dict(), - app_model_config_override=is_model_config_override, + app_model=app_model, + app_model_config=app_model_config, + app_model_config_dict=app_model_config_dict, user=user, invoke_from=invoke_from, inputs=inputs, @@ -189,17 +153,19 @@ class CompletionService: # parse files message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) - file_objs = message_file_parser.transform_message_files( - message.files, app_model_config - ) + file_upload_entity = FileUploadConfigManager.convert(current_app_model_config.to_dict()) + if file_upload_entity: + file_objs = message_file_parser.transform_message_files( + message.files, file_upload_entity + ) + else: + file_objs = [] - application_manager = AppManager() + application_manager = EasyUIBasedAppManager() return application_manager.generate( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - app_model_config_id=app_model_config.id, + app_model=app_model, + app_model_config=current_app_model_config, app_model_config_dict=app_model_config.to_dict(), - app_model_config_override=True, user=user, invoke_from=invoke_from, inputs=message.inputs, @@ -212,46 +178,3 @@ class CompletionService: } ) - @classmethod - def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig): - if user_inputs is None: - user_inputs = {} - - filtered_inputs = {} - - # Filter input variables from form configuration, handle required fields, default values, and option values - input_form_config = app_model_config.user_input_form_list - for config in input_form_config: - input_config = list(config.values())[0] - 
variable = input_config["variable"] - - input_type = list(config.keys())[0] - - if variable not in user_inputs or not user_inputs[variable]: - if input_type == "external_data_tool": - continue - if "required" in input_config and input_config["required"]: - raise ValueError(f"{variable} is required in input form") - else: - filtered_inputs[variable] = input_config["default"] if "default" in input_config else "" - continue - - value = user_inputs[variable] - - if value: - if not isinstance(value, str): - raise ValueError(f"{variable} in input form must be a string") - - if input_type == "select": - options = input_config["options"] if "options" in input_config else [] - if value not in options: - raise ValueError(f"{variable} in input form must be one of the following: {options}") - else: - if 'max_length' in input_config: - max_length = input_config['max_length'] - if len(value) > max_length: - raise ValueError(f'{variable} in input form must be less than {max_length} characters') - - filtered_inputs[variable] = value.replace('\x00', '') if value else None - - return filtered_inputs diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 6c0182dd9e..d62f198014 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,16 +1,9 @@ import json from typing import Optional -from core.app.app_manager import AppManager -from core.entities.application_entities import ( - DatasetEntity, - DatasetRetrieveConfigEntity, - ExternalDataVariableEntity, - FileUploadEntity, - ModelConfigEntity, - PromptTemplateEntity, - VariableEntity, -) +from core.app.app_config.entities import VariableEntity, ExternalDataVariableEntity, DatasetEntity, \ + DatasetRetrieveConfigEntity, ModelConfigEntity, PromptTemplateEntity, FileUploadEntity +from core.app.app_manager import EasyUIBasedAppManager from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils.encoders import jsonable_encoder @@ -36,7 +29,7 @@ class WorkflowConverter: - basic mode of chatbot app - - advanced mode of assistant app + - expert mode of chatbot app - completion app @@ -86,14 +79,11 @@ class WorkflowConverter: # get new app mode new_app_mode = self._get_new_app_mode(app_model) - app_model_config_dict = app_model_config.to_dict() - # convert app model config - application_manager = AppManager() - app_orchestration_config_entity = application_manager.convert_from_app_model_config_dict( - tenant_id=app_model.tenant_id, - app_model_config_dict=app_model_config_dict, - skip_check=True + application_manager = EasyUIBasedAppManager() + app_config = application_manager.convert_to_app_config( + app_model=app_model, + app_model_config=app_model_config ) # init workflow graph @@ -113,27 +103,27 @@ class WorkflowConverter: # convert to start node start_node = self._convert_to_start_node( - variables=app_orchestration_config_entity.variables + variables=app_config.variables ) graph['nodes'].append(start_node) # convert to http request node - if app_orchestration_config_entity.external_data_variables: + if app_config.external_data_variables: http_request_nodes = self._convert_to_http_request_node( app_model=app_model, - variables=app_orchestration_config_entity.variables, - external_data_variables=app_orchestration_config_entity.external_data_variables + variables=app_config.variables, + external_data_variables=app_config.external_data_variables ) for http_request_node in http_request_nodes: 
graph = self._append_node(graph, http_request_node) # convert to knowledge retrieval node - if app_orchestration_config_entity.dataset: + if app_config.dataset: knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node( new_app_mode=new_app_mode, - dataset_config=app_orchestration_config_entity.dataset + dataset_config=app_config.dataset ) if knowledge_retrieval_node: @@ -143,9 +133,9 @@ class WorkflowConverter: llm_node = self._convert_to_llm_node( new_app_mode=new_app_mode, graph=graph, - model_config=app_orchestration_config_entity.model_config, - prompt_template=app_orchestration_config_entity.prompt_template, - file_upload=app_orchestration_config_entity.file_upload + model_config=app_config.model, + prompt_template=app_config.prompt_template, + file_upload=app_config.additional_features.file_upload ) graph = self._append_node(graph, llm_node) @@ -155,6 +145,8 @@ class WorkflowConverter: graph = self._append_node(graph, end_node) + app_model_config_dict = app_config.app_model_config_dict + # features if new_app_mode == AppMode.ADVANCED_CHAT: features = { diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 102c861733..c9efd056ff 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -2,8 +2,8 @@ import json from datetime import datetime from typing import Optional -from core.app.advanced_chat.config_validator import AdvancedChatAppConfigValidator -from core.app.workflow.config_validator import WorkflowAppConfigValidator +from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager +from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from extensions.ext_database import db from models.account import Account from models.model import App, AppMode @@ -162,13 +162,13 @@ class WorkflowService: def validate_features_structure(self, app_model: App, features: dict) -> dict: if app_model.mode == AppMode.ADVANCED_CHAT.value: - return AdvancedChatAppConfigValidator.config_validate( + return AdvancedChatAppConfigManager.config_validate( tenant_id=app_model.tenant_id, config=features, only_structure_validate=True ) elif app_model.mode == AppMode.WORKFLOW.value: - return WorkflowAppConfigValidator.config_validate( + return WorkflowAppConfigManager.config_validate( tenant_id=app_model.tenant_id, config=features, only_structure_validate=True diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 69acb23681..4357c6405c 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -2,8 +2,8 @@ from unittest.mock import MagicMock import pytest -from core.entities.application_entities import PromptTemplateEntity, AdvancedCompletionPromptTemplateEntity, \ - ModelConfigEntity, AdvancedChatPromptTemplateEntity, AdvancedChatMessageEntity +from core.app.app_config.entities import PromptTemplateEntity, AdvancedCompletionPromptTemplateEntity, \ + ModelConfigEntity, AdvancedChatPromptTemplateEntity, AdvancedChatMessageEntity, FileUploadEntity from core.file.file_obj import FileObj, FileType, FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole @@ -137,11 +137,11 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg 
type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, url="https://example.com/image1.jpg", - file_config={ - "image": { + file_upload_entity=FileUploadEntity( + image_config={ "detail": "high", } - } + ) ) ] diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index 8a260b0507..9796fc5558 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -1,6 +1,6 @@ from unittest.mock import MagicMock -from core.entities.application_entities import ModelConfigEntity +from core.app.app_config.entities import ModelConfigEntity from core.entities.provider_configuration import ProviderModelBundle from core.model_runtime.entities.message_entities import UserPromptMessage from core.model_runtime.entities.model_entities import ModelPropertyKey, AIModelEntity, ParameterRule diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index a95a6dc52f..70f6070c6b 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -1,6 +1,6 @@ from unittest.mock import MagicMock -from core.entities.application_entities import ModelConfigEntity +from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage from core.prompt.simple_prompt_transform import SimplePromptTransform @@ -139,7 +139,7 @@ def test_get_common_chat_app_prompt_template_with_p(): def test__get_chat_model_prompt_messages(): - model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock = MagicMock(spec=EasyUIBasedModelConfigEntity) model_config_mock.provider = 'openai' model_config_mock.model = 'gpt-4' @@ -191,7 +191,7 @@ def test__get_chat_model_prompt_messages(): def test__get_completion_model_prompt_messages(): - model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock = MagicMock(spec=EasyUIBasedModelConfigEntity) model_config_mock.provider = 'openai' model_config_mock.model = 'gpt-3.5-turbo-instruct' diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index d4edc73410..0ca8ae135c 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock import pytest -from core.entities.application_entities import VariableEntity, ExternalDataVariableEntity, DatasetEntity, \ +from core.app.app_config.entities import VariableEntity, ExternalDataVariableEntity, DatasetEntity, \ DatasetRetrieveConfigEntity, ModelConfigEntity, PromptTemplateEntity, AdvancedChatPromptTemplateEntity, \ AdvancedChatMessageEntity, AdvancedCompletionPromptTemplateEntity from core.helper import encrypter From 2eaae6742a9ad9e450a854a886cca544880f01b7 Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 2 Mar 2024 02:40:26 +0800 Subject: [PATCH 092/200] lint fix --- api/core/agent/base_agent_runner.py | 7 ++++--- api/core/agent/cot_agent_runner.py | 2 +- api/core/app/app_manager.py | 8 ++++---- api/core/memory/token_buffer_memory.py | 1 - api/core/prompt/advanced_prompt_transform.py | 2 +- 
api/core/rag/retrieval/dataset_retrieval.py | 2 +- api/core/tools/tool/dataset_retriever_tool.py | 2 +- api/services/workflow/workflow_converter.py | 11 +++++++++-- 8 files changed, 21 insertions(+), 14 deletions(-) diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 529240aecb..f22ca7653f 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -9,12 +9,13 @@ from core.agent.entities import AgentEntity, AgentToolEntity from core.app.app_queue_manager import AppQueueManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig from core.app.apps.base_app_runner import AppRunner -from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler -from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.app.entities.app_invoke_entities import ( EasyUIBasedAppGenerateEntity, - InvokeFrom, EasyUIBasedModelConfigEntity, + EasyUIBasedModelConfigEntity, + InvokeFrom, ) +from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.file.message_file_parser import FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 5b345f4da0..8b444ef3be 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -4,8 +4,8 @@ from collections.abc import Generator from typing import Literal, Union from core.agent.base_agent_runner import BaseAgentRunner -from core.app.app_queue_manager import PublishFrom from core.agent.entities import AgentPromptEntity, AgentScratchpadUnit +from core.app.app_queue_manager import PublishFrom from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, diff --git a/api/core/app/app_manager.py b/api/core/app/app_manager.py index 98ebe2c87d..ea8a97f878 100644 --- a/api/core/app/app_manager.py +++ b/api/core/app/app_manager.py @@ -9,26 +9,26 @@ from flask import Flask, current_app from pydantic import ValidationError from core.app.app_config.easy_ui_based_app.model_config.converter import EasyUIBasedModelConfigEntityConverter -from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom, EasyUIBasedAppConfig, VariableEntity +from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom, VariableEntity +from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.agent_chat.app_runner import AgentChatAppRunner -from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.chat.app_runner import ChatAppRunner from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.app.apps.completion.app_runner import CompletionAppRunner -from core.app.generate_task_pipeline import GenerateTaskPipeline from core.app.entities.app_invoke_entities import ( EasyUIBasedAppGenerateEntity, InvokeFrom, ) +from core.app.generate_task_pipeline import 
GenerateTaskPipeline from core.file.file_obj import FileObj from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db from models.account import Account -from models.model import App, Conversation, EndUser, Message, MessageFile, AppMode, AppModelConfig +from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile logger = logging.getLogger(__name__) diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 4fe150e983..471400f09b 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,4 +1,3 @@ -from core.app.app_config.entities import FileUploadEntity from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.file.message_file_parser import MessageFileParser from core.model_manager import ModelInstance diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 129c2a4cd2..cdd03b85f1 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,6 +1,6 @@ from typing import Optional -from core.app.app_config.entities import PromptTemplateEntity, AdvancedCompletionPromptTemplateEntity +from core.app.app_config.entities import AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 8f1221adc7..37581f1e92 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -3,9 +3,9 @@ from typing import Optional, cast from langchain.tools import BaseTool from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity, InvokeFrom from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.agent_entities import PlanningStrategy -from core.app.entities.app_invoke_entities import InvokeFrom, EasyUIBasedModelConfigEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py index 80062e606a..1522d3af09 100644 --- a/api/core/tools/tool/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever_tool.py @@ -3,8 +3,8 @@ from typing import Any from langchain.tools import BaseTool from core.app.app_config.entities import DatasetRetrieveConfigEntity -from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.app.entities.app_invoke_entities import InvokeFrom +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.tools.entities.common_entities import I18nObject from 
core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index d62f198014..b3061cc255 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,8 +1,15 @@ import json from typing import Optional -from core.app.app_config.entities import VariableEntity, ExternalDataVariableEntity, DatasetEntity, \ - DatasetRetrieveConfigEntity, ModelConfigEntity, PromptTemplateEntity, FileUploadEntity +from core.app.app_config.entities import ( + DatasetEntity, + DatasetRetrieveConfigEntity, + ExternalDataVariableEntity, + FileUploadEntity, + ModelConfigEntity, + PromptTemplateEntity, + VariableEntity, +) from core.app.app_manager import EasyUIBasedAppManager from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode From b80092ea1243fc039a45d85b04f1744f9c20db49 Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 2 Mar 2024 02:40:31 +0800 Subject: [PATCH 093/200] lint fix --- api/core/agent/entities.py | 2 +- api/core/app/app_config/base_app_config_manager.py | 7 ++++--- .../easy_ui_based_app/model_config/converter.py | 1 - .../easy_ui_based_app/model_config/manager.py | 2 +- .../easy_ui_based_app/prompt_template/manager.py | 7 +++++-- .../app_config/easy_ui_based_app/variables/manager.py | 5 ++--- .../app_config/features/opening_statement/manager.py | 3 +-- api/core/app/apps/advanced_chat/app_config_manager.py | 7 ++++--- api/core/app/apps/agent_chat/app_config_manager.py | 11 ++++++----- api/core/app/apps/base_app_runner.py | 9 +++++---- api/core/app/apps/chat/app_config_manager.py | 9 +++++---- api/core/app/apps/chat/app_runner.py | 4 ++-- api/core/app/apps/completion/app_config_manager.py | 4 ++-- api/core/app/apps/completion/app_runner.py | 4 ++-- api/core/app/apps/workflow/app_config_manager.py | 2 +- 15 files changed, 41 insertions(+), 36 deletions(-) diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 0fbfdc2636..e7016d6030 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Literal, Any, Union, Optional +from typing import Any, Literal, Optional, Union from pydantic import BaseModel diff --git a/api/core/app/app_config/base_app_config_manager.py b/api/core/app/app_config/base_app_config_manager.py index b3c773203d..e09aa03766 100644 --- a/api/core/app/app_config/base_app_config_manager.py +++ b/api/core/app/app_config/base_app_config_manager.py @@ -1,4 +1,4 @@ -from typing import Union, Optional +from typing import Optional, Union from core.app.app_config.entities import AppAdditionalFeatures, EasyUIBasedAppModelConfigFrom from core.app.app_config.features.file_upload.manager import FileUploadConfigManager @@ -6,8 +6,9 @@ from core.app.app_config.features.more_like_this.manager import MoreLikeThisConf from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager -from core.app.app_config.features.suggested_questions_after_answer.manager import \ - SuggestedQuestionsAfterAnswerConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import ( + SuggestedQuestionsAfterAnswerConfigManager, +) from 
core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager from models.model import AppModelConfig diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index 05fcb10791..610e9bce32 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -2,7 +2,6 @@ from typing import cast from core.app.app_config.entities import EasyUIBasedAppConfig from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity - from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.entities.model_entities import ModelType diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index 5cca2bc1a7..730a9527cf 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -1,5 +1,5 @@ from core.app.app_config.entities import ModelConfigEntity -from core.model_runtime.entities.model_entities import ModelType, ModelPropertyKey +from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.model_providers import model_provider_factory from core.provider_manager import ProviderManager diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index 5629d0d09e..1f410758aa 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -1,5 +1,8 @@ -from core.app.app_config.entities import PromptTemplateEntity, \ - AdvancedChatPromptTemplateEntity, AdvancedCompletionPromptTemplateEntity +from core.app.app_config.entities import ( + AdvancedChatPromptTemplateEntity, + AdvancedCompletionPromptTemplateEntity, + PromptTemplateEntity, +) from core.model_runtime.entities.message_entities import PromptMessageRole from core.prompt.simple_prompt_transform import ModelMode from models.model import AppMode diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index ff962a5439..1237da502b 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -1,13 +1,12 @@ import re -from typing import Tuple -from core.app.app_config.entities import VariableEntity, ExternalDataVariableEntity +from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity from core.external_data_tool.factory import ExternalDataToolFactory class BasicVariablesConfigManager: @classmethod - def convert(cls, config: dict) -> Tuple[list[VariableEntity], list[ExternalDataVariableEntity]]: + def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]: """ Convert model config to model config diff --git a/api/core/app/app_config/features/opening_statement/manager.py b/api/core/app/app_config/features/opening_statement/manager.py index 6183c6e749..0d8a71bfcf 100644 --- a/api/core/app/app_config/features/opening_statement/manager.py +++ 
b/api/core/app/app_config/features/opening_statement/manager.py @@ -1,9 +1,8 @@ -from typing import Tuple class OpeningStatementConfigManager: @classmethod - def convert(cls, config: dict) -> Tuple[str, list]: + def convert(cls, config: dict) -> tuple[str, list]: """ Convert model config to model config diff --git a/api/core/app/apps/advanced_chat/app_config_manager.py b/api/core/app/apps/advanced_chat/app_config_manager.py index ab7857c4ad..d0909ead70 100644 --- a/api/core/app/apps/advanced_chat/app_config_manager.py +++ b/api/core/app/apps/advanced_chat/app_config_manager.py @@ -5,11 +5,12 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager -from core.app.app_config.features.suggested_questions_after_answer.manager import \ - SuggestedQuestionsAfterAnswerConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import ( + SuggestedQuestionsAfterAnswerConfigManager, +) from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager -from models.model import AppMode, App +from models.model import App, AppMode from models.workflow import Workflow diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index 96dac4bd01..55a04832aa 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -3,22 +3,23 @@ from typing import Optional from core.agent.entities import AgentEntity from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager -from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager -from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom, DatasetEntity +from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager -from core.app.app_config.features.suggested_questions_after_answer.manager import \ - SuggestedQuestionsAfterAnswerConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import ( + SuggestedQuestionsAfterAnswerConfigManager, +) from core.app.app_config.features.text_to_speech.manager import 
TextToSpeechConfigManager from core.entities.agent_entities import PlanningStrategy -from models.model import AppMode, App, AppModelConfig +from models.model import App, AppMode, AppModelConfig OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 93f819af08..64c1a46491 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -2,14 +2,15 @@ import time from collections.abc import Generator from typing import Optional, Union, cast -from core.app.app_config.entities import PromptTemplateEntity, ExternalDataVariableEntity +from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.app_queue_manager import AppQueueManager, PublishFrom -from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature -from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature from core.app.entities.app_invoke_entities import ( EasyUIBasedAppGenerateEntity, - InvokeFrom, EasyUIBasedModelConfigEntity, + EasyUIBasedModelConfigEntity, + InvokeFrom, ) +from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature +from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature from core.external_data_tool.external_data_fetch import ExternalDataFetch from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory diff --git a/api/core/app/apps/chat/app_config_manager.py b/api/core/app/apps/chat/app_config_manager.py index 62b2aaae5a..ff0195563e 100644 --- a/api/core/app/apps/chat/app_config_manager.py +++ b/api/core/app/apps/chat/app_config_manager.py @@ -1,20 +1,21 @@ from typing import Optional from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager -from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager -from core.app.app_config.features.suggested_questions_after_answer.manager import \ - SuggestedQuestionsAfterAnswerConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import ( + SuggestedQuestionsAfterAnswerConfigManager, +) from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager -from models.model import AppMode, App, AppModelConfig +from models.model import App, AppMode, AppModelConfig class ChatAppConfig(EasyUIBasedAppConfig): diff --git a/api/core/app/apps/chat/app_runner.py 
b/api/core/app/apps/chat/app_runner.py index 403a2d4476..1b256f11c4 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -2,12 +2,12 @@ import logging from typing import cast from core.app.app_queue_manager import AppQueueManager, PublishFrom -from core.app.apps.chat.app_config_manager import ChatAppConfig from core.app.apps.base_app_runner import AppRunner -from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.app.apps.chat.app_config_manager import ChatAppConfig from core.app.entities.app_invoke_entities import ( EasyUIBasedAppGenerateEntity, ) +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationException diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py index b920f369b5..6bdb7cc4b3 100644 --- a/api/core/app/apps/completion/app_config_manager.py +++ b/api/core/app/apps/completion/app_config_manager.py @@ -1,16 +1,16 @@ from typing import Optional from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager -from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager -from models.model import AppMode, App, AppModelConfig +from models.model import App, AppMode, AppModelConfig class CompletionAppConfig(EasyUIBasedAppConfig): diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 8f0f191d45..d60e14aaeb 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -2,12 +2,12 @@ import logging from typing import cast from core.app.app_queue_manager import AppQueueManager -from core.app.apps.completion.app_config_manager import CompletionAppConfig from core.app.apps.base_app_runner import AppRunner -from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.app.apps.completion.app_config_manager import CompletionAppConfig from core.app.entities.app_invoke_entities import ( EasyUIBasedAppGenerateEntity, ) +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.model_manager import ModelInstance from core.moderation.base import ModerationException from core.rag.retrieval.dataset_retrieval import DatasetRetrieval diff --git a/api/core/app/apps/workflow/app_config_manager.py b/api/core/app/apps/workflow/app_config_manager.py index 35da72b63e..194339a23b 100644 --- 
a/api/core/app/apps/workflow/app_config_manager.py +++ b/api/core/app/apps/workflow/app_config_manager.py @@ -4,7 +4,7 @@ from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager -from models.model import AppMode, App +from models.model import App, AppMode from models.workflow import Workflow From 06b05163f673f533e48810e72eef9a1e60cf89f4 Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 2 Mar 2024 15:53:40 +0800 Subject: [PATCH 094/200] update app import response --- api/controllers/console/app/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 98636fa95f..db23a028cd 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -76,7 +76,7 @@ class AppImportApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(app_detail_fields) + @marshal_with(app_detail_fields_with_site) @cloud_edition_billing_resource_check('apps') def post(self): """Import app""" From 09dfe80718d85fef7e31a2547d174b2af6355fd1 Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 2 Mar 2024 15:57:34 +0800 Subject: [PATCH 095/200] add app copy api --- api/controllers/console/app/app.py | 29 ++++++++++++++++++++++++++++- api/services/app_service.py | 5 +++-- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index db23a028cd..7b2411b96f 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -93,7 +93,7 @@ class AppImportApi(Resource): args = parser.parse_args() app_service = AppService() - app = app_service.import_app(current_user.current_tenant_id, args, current_user) + app = app_service.import_app(current_user.current_tenant_id, args['data'], args, current_user) return app, 201 @@ -180,6 +180,32 @@ class AppApi(Resource): return {'result': 'success'}, 204 +class AppCopyApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model + @marshal_with(app_detail_fields_with_site) + def post(self, app_model): + """Copy app""" + # The role of the current user in the ta table must be admin or owner + if not current_user.is_admin_or_owner: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument('name', type=str, location='json') + parser.add_argument('description', type=str, location='json') + parser.add_argument('icon', type=str, location='json') + parser.add_argument('icon_background', type=str, location='json') + args = parser.parse_args() + + app_service = AppService() + data = app_service.export_app(app_model) + app = app_service.import_app(current_user.current_tenant_id, data, args, current_user) + + return app, 201 + + class AppExportApi(Resource): @setup_required @login_required @@ -266,6 +292,7 @@ class AppApiStatus(Resource): api.add_resource(AppListApi, '/apps') api.add_resource(AppImportApi, '/apps/import') api.add_resource(AppApi, '/apps/') +api.add_resource(AppCopyApi, '/apps//copy') api.add_resource(AppExportApi, '/apps//export') api.add_resource(AppNameApi, '/apps//name') api.add_resource(AppIconApi, '/apps//icon') diff --git a/api/services/app_service.py b/api/services/app_service.py 
index e0a7835cb7..f1d0e3df19 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -124,15 +124,16 @@ class AppService: return app - def import_app(self, tenant_id: str, args: dict, account: Account) -> App: + def import_app(self, tenant_id: str, data: str, args: dict, account: Account) -> App: """ Import app :param tenant_id: tenant id + :param data: import data :param args: request args :param account: Account instance """ try: - import_data = yaml.safe_load(args['data']) + import_data = yaml.safe_load(data) except yaml.YAMLError as e: raise ValueError("Invalid YAML format in data argument.") From e498efce2d79587628bcb8c904af2843971e8549 Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 3 Mar 2024 04:18:38 +0800 Subject: [PATCH 096/200] refactor app generate --- api/controllers/console/app/completion.py | 6 +- api/core/agent/base_agent_runner.py | 13 +- .../model_config/converter.py | 8 +- api/core/app/app_manager.py | 468 ------------------ .../apps/advanced_chat/app_config_manager.py | 8 +- .../app/apps/agent_chat/app_config_manager.py | 25 +- api/core/app/apps/agent_chat/app_generator.py | 194 ++++++++ api/core/app/apps/agent_chat/app_runner.py | 7 +- api/core/app/apps/base_app_generator.py | 42 ++ api/core/app/apps/base_app_runner.py | 13 +- api/core/app/apps/chat/app_config_manager.py | 25 +- api/core/app/apps/chat/app_generator.py | 194 ++++++++ api/core/app/apps/chat/app_runner.py | 4 +- .../app/apps/completion/app_config_manager.py | 21 +- api/core/app/apps/completion/app_generator.py | 292 +++++++++++ api/core/app/apps/completion/app_runner.py | 4 +- .../app/apps/message_based_app_generator.py | 251 ++++++++++ .../app/apps/workflow/app_config_manager.py | 2 +- api/core/app/entities/app_invoke_entities.py | 74 ++- .../hosting_moderation/hosting_moderation.py | 2 +- api/core/app/generate_task_pipeline.py | 18 +- api/core/helper/moderation.py | 4 +- api/core/prompt/advanced_prompt_transform.py | 10 +- api/core/prompt/prompt_transform.py | 6 +- api/core/prompt/simple_prompt_transform.py | 10 +- api/core/rag/retrieval/agent/llm_chain.py | 4 +- .../agent/multi_dataset_router_agent.py | 6 +- .../structed_multi_dataset_router_agent.py | 4 +- .../retrieval/agent_based_dataset_executor.py | 6 +- api/core/rag/retrieval/dataset_retrieval.py | 4 +- .../deduct_quota_when_messaeg_created.py | 4 +- ...vider_last_used_at_when_messaeg_created.py | 4 +- api/services/completion_service.py | 209 ++------ api/services/workflow/workflow_converter.py | 39 +- .../prompt/test_simple_prompt_transform.py | 6 +- 35 files changed, 1236 insertions(+), 751 deletions(-) delete mode 100644 api/core/app/app_manager.py create mode 100644 api/core/app/apps/agent_chat/app_generator.py create mode 100644 api/core/app/apps/base_app_generator.py create mode 100644 api/core/app/apps/chat/app_generator.py create mode 100644 api/core/app/apps/completion/app_generator.py create mode 100644 api/core/app/apps/message_based_app_generator.py diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index ed1522c0cd..fd6cfadfef 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -59,8 +59,7 @@ class CompletionMessageApi(Resource): user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, - streaming=streaming, - is_model_config_override=True + streaming=streaming ) return compact_response(response) @@ -126,8 +125,7 @@ class ChatMessageApi(Resource): user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, - 
streaming=streaming, - is_model_config_override=True + streaming=streaming ) return compact_response(response) diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index f22ca7653f..ef530b9122 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -10,9 +10,8 @@ from core.app.app_queue_manager import AppQueueManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import ( - EasyUIBasedAppGenerateEntity, - EasyUIBasedModelConfigEntity, - InvokeFrom, + ModelConfigWithCredentialsEntity, + InvokeFrom, AgentChatAppGenerateEntity, ) from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler @@ -49,9 +48,9 @@ logger = logging.getLogger(__name__) class BaseAgentRunner(AppRunner): def __init__(self, tenant_id: str, - application_generate_entity: EasyUIBasedAppGenerateEntity, + application_generate_entity: AgentChatAppGenerateEntity, app_config: AgentChatAppConfig, - model_config: EasyUIBasedModelConfigEntity, + model_config: ModelConfigWithCredentialsEntity, config: AgentEntity, queue_manager: AppQueueManager, message: Message, @@ -123,8 +122,8 @@ class BaseAgentRunner(AppRunner): else: self.stream_tool_call = False - def _repack_app_generate_entity(self, app_generate_entity: EasyUIBasedAppGenerateEntity) \ - -> EasyUIBasedAppGenerateEntity: + def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \ + -> AgentChatAppGenerateEntity: """ Repack app generate entity """ diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index 610e9bce32..5c9b2cfec7 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -1,7 +1,7 @@ from typing import cast from core.app.app_config.entities import EasyUIBasedAppConfig -from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.entities.model_entities import ModelType @@ -9,11 +9,11 @@ from core.model_runtime.model_providers.__base.large_language_model import Large from core.provider_manager import ProviderManager -class EasyUIBasedModelConfigEntityConverter: +class ModelConfigConverter: @classmethod def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) \ - -> EasyUIBasedModelConfigEntity: + -> ModelConfigWithCredentialsEntity: """ Convert app model config dict to entity. 
:param app_config: app config @@ -91,7 +91,7 @@ class EasyUIBasedModelConfigEntityConverter: if not skip_check and not model_schema: raise ValueError(f"Model {model_name} not exist.") - return EasyUIBasedModelConfigEntity( + return ModelConfigWithCredentialsEntity( provider=model_config.provider, model=model_config.model, model_schema=model_schema, diff --git a/api/core/app/app_manager.py b/api/core/app/app_manager.py deleted file mode 100644 index ea8a97f878..0000000000 --- a/api/core/app/app_manager.py +++ /dev/null @@ -1,468 +0,0 @@ -import json -import logging -import threading -import uuid -from collections.abc import Generator -from typing import Any, Optional, Union, cast - -from flask import Flask, current_app -from pydantic import ValidationError - -from core.app.app_config.easy_ui_based_app.model_config.converter import EasyUIBasedModelConfigEntityConverter -from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom, VariableEntity -from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom -from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager -from core.app.apps.agent_chat.app_runner import AgentChatAppRunner -from core.app.apps.chat.app_config_manager import ChatAppConfigManager -from core.app.apps.chat.app_runner import ChatAppRunner -from core.app.apps.completion.app_config_manager import CompletionAppConfigManager -from core.app.apps.completion.app_runner import CompletionAppRunner -from core.app.entities.app_invoke_entities import ( - EasyUIBasedAppGenerateEntity, - InvokeFrom, -) -from core.app.generate_task_pipeline import GenerateTaskPipeline -from core.file.file_obj import FileObj -from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from extensions.ext_database import db -from models.account import Account -from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile - -logger = logging.getLogger(__name__) - - -class EasyUIBasedAppManager: - - def generate(self, app_model: App, - app_model_config: AppModelConfig, - user: Union[Account, EndUser], - invoke_from: InvokeFrom, - inputs: dict[str, str], - app_model_config_dict: Optional[dict] = None, - query: Optional[str] = None, - files: Optional[list[FileObj]] = None, - conversation: Optional[Conversation] = None, - stream: bool = False, - extras: Optional[dict[str, Any]] = None) \ - -> Union[dict, Generator]: - """ - Generate App response. 
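
# The monolithic EasyUIBasedAppManager deleted here dispatched every app mode
# inside one generate(); after this commit that dispatch moves to the service
# layer (see completion_service.py in the diffstat). A rough, illustrative
# sketch of the replacement routing, assuming the per-mode generators
# introduced later in this patch:
from core.app.apps.chat.app_generator import ChatAppGenerator
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
from core.app.apps.completion.app_generator import CompletionAppGenerator

_GENERATORS = {
    'chat': ChatAppGenerator,
    'agent-chat': AgentChatAppGenerator,
    'completion': CompletionAppGenerator,
}

def generate(app_model, user, args, invoke_from, stream=True):
    generator_cls = _GENERATORS.get(app_model.mode)
    if generator_cls is None:
        raise ValueError(f"Invalid app mode {app_model.mode}")
    return generator_cls().generate(app_model=app_model, user=user, args=args,
                                    invoke_from=invoke_from, stream=stream)
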
- - :param app_model: App - :param app_model_config: app model config - :param user: account or end user - :param invoke_from: invoke from source - :param inputs: inputs - :param app_model_config_dict: app model config dict - :param query: query - :param files: file obj list - :param conversation: conversation - :param stream: is stream - :param extras: extras - """ - # init task id - task_id = str(uuid.uuid4()) - - # convert to app config - app_config = self.convert_to_app_config( - app_model=app_model, - app_model_config=app_model_config, - app_model_config_dict=app_model_config_dict, - conversation=conversation - ) - - # init application generate entity - application_generate_entity = EasyUIBasedAppGenerateEntity( - task_id=task_id, - app_config=app_config, - model_config=EasyUIBasedModelConfigEntityConverter.convert(app_config), - conversation_id=conversation.id if conversation else None, - inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), - query=query.replace('\x00', '') if query else None, - files=files if files else [], - user_id=user.id, - stream=stream, - invoke_from=invoke_from, - extras=extras - ) - - if not stream and application_generate_entity.app_config.app_mode == AppMode.AGENT_CHAT: - raise ValueError("Agent app is not supported in blocking mode.") - - # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity) - - # init queue manager - queue_manager = AppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id - ) - - # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'conversation_id': conversation.id, - 'message_id': message.id, - }) - - worker_thread.start() - - # return response or stream generator - return self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - stream=stream - ) - - def convert_to_app_config(self, app_model: App, - app_model_config: AppModelConfig, - app_model_config_dict: Optional[dict] = None, - conversation: Optional[Conversation] = None) -> EasyUIBasedAppConfig: - if app_model_config_dict: - config_from = EasyUIBasedAppModelConfigFrom.ARGS - elif conversation: - config_from = EasyUIBasedAppModelConfigFrom.CONVERSATION_SPECIFIC_CONFIG - else: - config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG - - app_mode = AppMode.value_of(app_model.mode) - if app_mode == AppMode.AGENT_CHAT or app_model.is_agent: - app_model.mode = AppMode.AGENT_CHAT.value - app_config = AgentChatAppConfigManager.config_convert( - app_model=app_model, - config_from=config_from, - app_model_config=app_model_config, - config_dict=app_model_config_dict - ) - elif app_mode == AppMode.CHAT: - app_config = ChatAppConfigManager.config_convert( - app_model=app_model, - config_from=config_from, - app_model_config=app_model_config, - config_dict=app_model_config_dict - ) - elif app_mode == AppMode.COMPLETION: - app_config = CompletionAppConfigManager.config_convert( - app_model=app_model, - config_from=config_from, - app_model_config=app_model_config, - config_dict=app_model_config_dict - ) - else: 
- raise ValueError("Invalid app mode") - - return app_config - - def _get_cleaned_inputs(self, user_inputs: dict, app_config: EasyUIBasedAppConfig): - if user_inputs is None: - user_inputs = {} - - filtered_inputs = {} - - # Filter input variables from form configuration, handle required fields, default values, and option values - variables = app_config.variables - for variable_config in variables: - variable = variable_config.variable - - if variable not in user_inputs or not user_inputs[variable]: - if variable_config.required: - raise ValueError(f"{variable} is required in input form") - else: - filtered_inputs[variable] = variable_config.default if variable_config.default is not None else "" - continue - - value = user_inputs[variable] - - if value: - if not isinstance(value, str): - raise ValueError(f"{variable} in input form must be a string") - - if variable_config.type == VariableEntity.Type.SELECT: - options = variable_config.options if variable_config.options is not None else [] - if value not in options: - raise ValueError(f"{variable} in input form must be one of the following: {options}") - else: - if variable_config.max_length is not None: - max_length = variable_config.max_length - if len(value) > max_length: - raise ValueError(f'{variable} in input form must be less than {max_length} characters') - - filtered_inputs[variable] = value.replace('\x00', '') if value else None - - return filtered_inputs - - def _generate_worker(self, flask_app: Flask, - application_generate_entity: EasyUIBasedAppGenerateEntity, - queue_manager: AppQueueManager, - conversation_id: str, - message_id: str) -> None: - """ - Generate worker in a new thread. - :param flask_app: Flask app - :param application_generate_entity: application generate entity - :param queue_manager: queue manager - :param conversation_id: conversation ID - :param message_id: message ID - :return: - """ - with flask_app.app_context(): - try: - # get conversation and message - conversation = self._get_conversation(conversation_id) - message = self._get_message(message_id) - - if application_generate_entity.app_config.app_mode == AppMode.AGENT_CHAT: - # agent app - runner = AgentChatAppRunner() - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message - ) - elif application_generate_entity.app_config.app_mode == AppMode.CHAT: - # chatbot app - runner = ChatAppRunner() - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message - ) - elif application_generate_entity.app_config.app_mode == AppMode.COMPLETION: - # completion app - runner = CompletionAppRunner() - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - message=message - ) - else: - raise ValueError("Invalid app mode") - except ConversationTaskStoppedException: - pass - except InvokeAuthorizationError: - queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER - ) - except ValidationError as e: - logger.exception("Validation Error when generating") - queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - except (ValueError, InvokeError) as e: - queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - except Exception as e: - logger.exception("Unknown Error when generating") - queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - finally: - 
db.session.remove() - - def _handle_response(self, application_generate_entity: EasyUIBasedAppGenerateEntity, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - stream: bool = False) -> Union[dict, Generator]: - """ - Handle response. - :param application_generate_entity: application generate entity - :param queue_manager: queue manager - :param conversation: conversation - :param message: message - :param stream: is stream - :return: - """ - # init generate task pipeline - generate_task_pipeline = GenerateTaskPipeline( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message - ) - - try: - return generate_task_pipeline.process(stream=stream) - except ValueError as e: - if e.args[0] == "I/O operation on closed file.": # ignore this error - raise ConversationTaskStoppedException() - else: - logger.exception(e) - raise e - finally: - db.session.remove() - - def _init_generate_records(self, application_generate_entity: EasyUIBasedAppGenerateEntity) \ - -> tuple[Conversation, Message]: - """ - Initialize generate records - :param application_generate_entity: application generate entity - :return: - """ - model_type_instance = application_generate_entity.model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - model_schema = model_type_instance.get_model_schema( - model=application_generate_entity.model_config.model, - credentials=application_generate_entity.model_config.credentials - ) - - app_config = application_generate_entity.app_config - - app_record = (db.session.query(App) - .filter(App.id == app_config.app_id).first()) - - app_mode = app_record.mode - - # get from source - end_user_id = None - account_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - from_source = 'api' - end_user_id = application_generate_entity.user_id - else: - from_source = 'console' - account_id = application_generate_entity.user_id - - override_model_configs = None - if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS: - override_model_configs = app_config.app_model_config_dict - - introduction = '' - if app_mode == 'chat': - # get conversation introduction - introduction = self._get_conversation_introduction(application_generate_entity) - - if not application_generate_entity.conversation_id: - conversation = Conversation( - app_id=app_record.id, - app_model_config_id=app_config.app_model_config_id, - model_provider=application_generate_entity.model_config.provider, - model_id=application_generate_entity.model_config.model, - override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, - mode=app_mode, - name='New conversation', - inputs=application_generate_entity.inputs, - introduction=introduction, - system_instruction="", - system_instruction_tokens=0, - status='normal', - from_source=from_source, - from_end_user_id=end_user_id, - from_account_id=account_id, - ) - - db.session.add(conversation) - db.session.commit() - else: - conversation = ( - db.session.query(Conversation) - .filter( - Conversation.id == application_generate_entity.conversation_id, - Conversation.app_id == app_record.id - ).first() - ) - - currency = model_schema.pricing.currency if model_schema.pricing else 'USD' - - message = Message( - app_id=app_record.id, - model_provider=application_generate_entity.model_config.provider, - 
model_id=application_generate_entity.model_config.model, - override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, - conversation_id=conversation.id, - inputs=application_generate_entity.inputs, - query=application_generate_entity.query or "", - message="", - message_tokens=0, - message_unit_price=0, - message_price_unit=0, - answer="", - answer_tokens=0, - answer_unit_price=0, - answer_price_unit=0, - provider_response_latency=0, - total_price=0, - currency=currency, - from_source=from_source, - from_end_user_id=end_user_id, - from_account_id=account_id, - agent_based=app_config.app_mode == AppMode.AGENT_CHAT, - ) - - db.session.add(message) - db.session.commit() - - for file in application_generate_entity.files: - message_file = MessageFile( - message_id=message.id, - type=file.type.value, - transfer_method=file.transfer_method.value, - belongs_to='user', - url=file.url, - upload_file_id=file.upload_file_id, - created_by_role=('account' if account_id else 'end_user'), - created_by=account_id or end_user_id, - ) - db.session.add(message_file) - db.session.commit() - - return conversation, message - - def _get_conversation_introduction(self, application_generate_entity: EasyUIBasedAppGenerateEntity) -> str: - """ - Get conversation introduction - :param application_generate_entity: application generate entity - :return: conversation introduction - """ - app_config = application_generate_entity.app_config - introduction = app_config.additional_features.opening_statement - - if introduction: - try: - inputs = application_generate_entity.inputs - prompt_template = PromptTemplateParser(template=introduction) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - introduction = prompt_template.format(prompt_inputs) - except KeyError: - pass - - return introduction - - def _get_conversation(self, conversation_id: str) -> Conversation: - """ - Get conversation by conversation id - :param conversation_id: conversation id - :return: conversation - """ - conversation = ( - db.session.query(Conversation) - .filter(Conversation.id == conversation_id) - .first() - ) - - return conversation - - def _get_message(self, message_id: str) -> Message: - """ - Get message by message id - :param message_id: message id - :return: message - """ - message = ( - db.session.query(Message) - .filter(Message.id == message_id) - .first() - ) - - return message diff --git a/api/core/app/apps/advanced_chat/app_config_manager.py b/api/core/app/apps/advanced_chat/app_config_manager.py index d0909ead70..72ba4c33d4 100644 --- a/api/core/app/apps/advanced_chat/app_config_manager.py +++ b/api/core/app/apps/advanced_chat/app_config_manager.py @@ -1,3 +1,5 @@ +from typing import Optional + from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.entities import WorkflowUIBasedAppConfig @@ -10,7 +12,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor ) from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager -from models.model import App, AppMode +from models.model import App, AppMode, Conversation from models.workflow import Workflow @@ -23,7 +25,9 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig): class 
AdvancedChatAppConfigManager(BaseAppConfigManager):
     @classmethod
-    def config_convert(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
+    def get_app_config(cls, app_model: App,
+                       workflow: Workflow,
+                       conversation: Optional[Conversation] = None) -> AdvancedChatAppConfig:
         features_dict = workflow.features_dict
 
         app_config = AdvancedChatAppConfig(
diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py
index 55a04832aa..57214f924a 100644
--- a/api/core/app/apps/agent_chat/app_config_manager.py
+++ b/api/core/app/apps/agent_chat/app_config_manager.py
@@ -19,7 +19,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
 )
 from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
 from core.entities.agent_entities import PlanningStrategy
-from models.model import App, AppMode, AppModelConfig
+from models.model import App, AppMode, AppModelConfig, Conversation
 
 OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"]
 
@@ -33,19 +33,30 @@ class AgentChatAppConfig(EasyUIBasedAppConfig):
 
 class AgentChatAppConfigManager(BaseAppConfigManager):
     @classmethod
-    def config_convert(cls, app_model: App,
-                       config_from: EasyUIBasedAppModelConfigFrom,
+    def get_app_config(cls, app_model: App,
                        app_model_config: AppModelConfig,
-                       config_dict: Optional[dict] = None) -> AgentChatAppConfig:
+                       conversation: Optional[Conversation] = None,
+                       override_config_dict: Optional[dict] = None) -> AgentChatAppConfig:
         """
         Convert app model config to agent chat app config
         :param app_model: app model
-        :param config_from: app model config from
         :param app_model_config: app model config
-        :param config_dict: app model config dict
+        :param conversation: conversation
+        :param override_config_dict: app model config dict
        :return:
         """
-        config_dict = cls.convert_to_config_dict(config_from, app_model_config, config_dict)
+        if override_config_dict:
+            config_from = EasyUIBasedAppModelConfigFrom.ARGS
+        elif conversation:
+            config_from = EasyUIBasedAppModelConfigFrom.CONVERSATION_SPECIFIC_CONFIG
+        else:
+            config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
+
+        if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
+            app_model_config_dict = app_model_config.to_dict()
+            config_dict = app_model_config_dict.copy()
+        else:
+            config_dict = override_config_dict
 
         app_config = AgentChatAppConfig(
             tenant_id=app_model.tenant_id,
diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py
new file mode 100644
index 0000000000..1ab456d822
--- /dev/null
+++ b/api/core/app/apps/agent_chat/app_generator.py
@@ -0,0 +1,194 @@
+import logging
+import threading
+import uuid
+from typing import Union, Any, Generator
+
+from flask import current_app, Flask
+from pydantic import ValidationError
+
+from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
+from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
+from core.app.app_queue_manager import ConversationTaskStoppedException, PublishFrom, AppQueueManager
+from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
+from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
+from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
+from core.app.entities.app_invoke_entities import InvokeFrom, AgentChatAppGenerateEntity
+from core.file.message_file_parser import
MessageFileParser +from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from extensions.ext_database import db +from models.account import Account +from models.model import App, EndUser + +logger = logging.getLogger(__name__) + + +class AgentChatAppGenerator(MessageBasedAppGenerator): + def generate(self, app_model: App, + user: Union[Account, EndUser], + args: Any, + invoke_from: InvokeFrom, + stream: bool = True) \ + -> Union[dict, Generator]: + """ + Generate App response. + + :param app_model: App + :param user: account or end user + :param args: request args + :param invoke_from: invoke from source + :param stream: is stream + """ + if not args.get('query'): + raise ValueError('query is required') + + query = args['query'] + if not isinstance(query, str): + raise ValueError('query must be a string') + + query = query.replace('\x00', '') + inputs = args['inputs'] + + extras = { + "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else True + } + + # get conversation + conversation = None + if args.get('conversation_id'): + conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + + # get app model config + app_model_config = self._get_app_model_config( + app_model=app_model, + conversation=conversation + ) + + # validate override model config + override_model_config_dict = None + if args.get('model_config'): + if invoke_from != InvokeFrom.DEBUGGER: + raise ValueError('Only in App debug mode can override model config') + + # validate config + override_model_config_dict = AgentChatAppConfigManager.config_validate( + tenant_id=app_model.tenant_id, + config=args.get('model_config') + ) + + # parse files + files = args['files'] if 'files' in args and args['files'] else [] + message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) + if file_upload_entity: + file_objs = message_file_parser.validate_and_transform_files_arg( + files, + file_upload_entity, + user + ) + else: + file_objs = [] + + # convert to app config + app_config = AgentChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=conversation, + override_config_dict=override_model_config_dict + ) + + # init application generate entity + application_generate_entity = AgentChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_config=ModelConfigConverter.convert(app_config), + conversation_id=conversation.id if conversation else None, + inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), + query=query, + files=file_objs, + user_id=user.id, + stream=stream, + invoke_from=invoke_from, + extras=extras + ) + + # init generate records + ( + conversation, + message + ) = self._init_generate_records(application_generate_entity, conversation) + + # init queue manager + queue_manager = AppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id + ) + + # new thread + worker_thread = threading.Thread(target=self._generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'application_generate_entity': application_generate_entity, + 'queue_manager': 
queue_manager, + 'conversation_id': conversation.id, + 'message_id': message.id, + }) + + worker_thread.start() + + # return response or stream generator + return self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + stream=stream + ) + + def _generate_worker(self, flask_app: Flask, + application_generate_entity: AgentChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation_id: str, + message_id: str) -> None: + """ + Generate worker in a new thread. + :param flask_app: Flask app + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param conversation_id: conversation ID + :param message_id: message ID + :return: + """ + with flask_app.app_context(): + try: + # get conversation and message + conversation = self._get_conversation(conversation_id) + message = self._get_message(message_id) + + # chatbot app + runner = AgentChatAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + except ConversationTaskStoppedException: + pass + except InvokeAuthorizationError: + queue_manager.publish_error( + InvokeAuthorizationError('Incorrect API key provided'), + PublishFrom.APPLICATION_MANAGER + ) + except ValidationError as e: + logger.exception("Validation Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except (ValueError, InvokeError) as e: + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except Exception as e: + logger.exception("Unknown Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + finally: + db.session.remove() diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 2f1de8f108..6bae5e1648 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -7,7 +7,8 @@ from core.agent.fc_agent_runner import FunctionCallAgentRunner from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig from core.app.apps.base_app_runner import AppRunner -from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity, EasyUIBasedModelConfigEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity, \ + AgentChatAppGenerateEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMUsage @@ -26,7 +27,7 @@ class AgentChatAppRunner(AppRunner): """ Agent Application Runner """ - def run(self, application_generate_entity: EasyUIBasedAppGenerateEntity, + def run(self, application_generate_entity: AgentChatAppGenerateEntity, queue_manager: AppQueueManager, conversation: Conversation, message: Message) -> None: @@ -292,7 +293,7 @@ class AgentChatAppRunner(AppRunner): 'pool': db_variables.variables }) - def _get_usage_of_all_agent_thoughts(self, model_config: EasyUIBasedModelConfigEntity, + def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigWithCredentialsEntity, message: Message) -> LLMUsage: """ Get usage of all agent thoughts diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py new file mode 100644 index 0000000000..65764021aa --- /dev/null 
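
# Each new generator repeats the same shape: build a generate entity, run the
# app runner on a worker thread, and stream chunks back through a queue. A
# minimal, runnable sketch of that producer/consumer flow with stand-in names
# (the real code publishes via AppQueueManager inside a Flask app context):
import queue
import threading

def _worker(out: queue.Queue) -> None:
    for chunk in ("Hello", ", ", "world"):  # stands in for streamed LLM chunks
        out.put(chunk)
    out.put(None)  # sentinel: generation finished

def generate_stream():
    out: queue.Queue = queue.Queue()
    threading.Thread(target=_worker, args=(out,), daemon=True).start()
    while (chunk := out.get()) is not None:
        yield chunk

print("".join(generate_stream()))  # -> Hello, world
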
+++ b/api/core/app/apps/base_app_generator.py @@ -0,0 +1,42 @@ +from core.app.app_config.entities import VariableEntity, AppConfig + + +class BaseAppGenerator: + def _get_cleaned_inputs(self, user_inputs: dict, app_config: AppConfig): + if user_inputs is None: + user_inputs = {} + + filtered_inputs = {} + + # Filter input variables from form configuration, handle required fields, default values, and option values + variables = app_config.variables + for variable_config in variables: + variable = variable_config.variable + + if variable not in user_inputs or not user_inputs[variable]: + if variable_config.required: + raise ValueError(f"{variable} is required in input form") + else: + filtered_inputs[variable] = variable_config.default if variable_config.default is not None else "" + continue + + value = user_inputs[variable] + + if value: + if not isinstance(value, str): + raise ValueError(f"{variable} in input form must be a string") + + if variable_config.type == VariableEntity.Type.SELECT: + options = variable_config.options if variable_config.options is not None else [] + if value not in options: + raise ValueError(f"{variable} in input form must be one of the following: {options}") + else: + if variable_config.max_length is not None: + max_length = variable_config.max_length + if len(value) > max_length: + raise ValueError(f'{variable} in input form must be less than {max_length} characters') + + filtered_inputs[variable] = value.replace('\x00', '') if value else None + + return filtered_inputs + diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 64c1a46491..ee70f161a2 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -5,9 +5,8 @@ from typing import Optional, Union, cast from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import ( - EasyUIBasedAppGenerateEntity, - EasyUIBasedModelConfigEntity, - InvokeFrom, + ModelConfigWithCredentialsEntity, + InvokeFrom, AppGenerateEntity, EasyUIBasedAppGenerateEntity, ) from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature @@ -27,7 +26,7 @@ from models.model import App, AppMode, Message, MessageAnnotation class AppRunner: def get_pre_calculate_rest_tokens(self, app_record: App, - model_config: EasyUIBasedModelConfigEntity, + model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, inputs: dict[str, str], files: list[FileObj], @@ -83,7 +82,7 @@ class AppRunner: return rest_tokens - def recale_llm_max_tokens(self, model_config: EasyUIBasedModelConfigEntity, + def recale_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]): # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit model_type_instance = model_config.provider_model_bundle.model_type_instance @@ -119,7 +118,7 @@ class AppRunner: model_config.parameters[parameter_rule.name] = max_tokens def organize_prompt_messages(self, app_record: App, - model_config: EasyUIBasedModelConfigEntity, + model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, inputs: dict[str, str], files: list[FileObj], @@ -292,7 +291,7 @@ class AppRunner: def moderation_for_inputs(self, app_id: str, tenant_id: str, - 
app_generate_entity: EasyUIBasedAppGenerateEntity,
+                              app_generate_entity: AppGenerateEntity,
                               inputs: dict,
                               query: str) -> tuple[bool, dict, str]:
         """
diff --git a/api/core/app/apps/chat/app_config_manager.py b/api/core/app/apps/chat/app_config_manager.py
index ff0195563e..ac69a92823 100644
--- a/api/core/app/apps/chat/app_config_manager.py
+++ b/api/core/app/apps/chat/app_config_manager.py
@@ -15,7 +15,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
     SuggestedQuestionsAfterAnswerConfigManager,
 )
 from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
-from models.model import App, AppMode, AppModelConfig
+from models.model import App, AppMode, AppModelConfig, Conversation
 
 
 class ChatAppConfig(EasyUIBasedAppConfig):
@@ -27,19 +27,30 @@ class ChatAppConfig(EasyUIBasedAppConfig):
 
 class ChatAppConfigManager(BaseAppConfigManager):
     @classmethod
-    def config_convert(cls, app_model: App,
-                       config_from: EasyUIBasedAppModelConfigFrom,
+    def get_app_config(cls, app_model: App,
                        app_model_config: AppModelConfig,
-                       config_dict: Optional[dict] = None) -> ChatAppConfig:
+                       conversation: Optional[Conversation] = None,
+                       override_config_dict: Optional[dict] = None) -> ChatAppConfig:
         """
         Convert app model config to chat app config
         :param app_model: app model
-        :param config_from: app model config from
         :param app_model_config: app model config
-        :param config_dict: app model config dict
+        :param conversation: conversation
+        :param override_config_dict: app model config dict
         :return:
         """
-        config_dict = cls.convert_to_config_dict(config_from, app_model_config, config_dict)
+        if override_config_dict:
+            config_from = EasyUIBasedAppModelConfigFrom.ARGS
+        elif conversation:
+            config_from = EasyUIBasedAppModelConfigFrom.CONVERSATION_SPECIFIC_CONFIG
+        else:
+            config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
+
+        if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
+            app_model_config_dict = app_model_config.to_dict()
+            config_dict = app_model_config_dict.copy()
+        else:
+            config_dict = override_config_dict
 
         app_config = ChatAppConfig(
             tenant_id=app_model.tenant_id,
diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py
new file mode 100644
index 0000000000..712822f3a5
--- /dev/null
+++ b/api/core/app/apps/chat/app_generator.py
@@ -0,0 +1,194 @@
+import logging
+import threading
+import uuid
+from typing import Union, Any, Generator
+
+from flask import current_app, Flask
+from pydantic import ValidationError
+
+from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
+from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
+from core.app.app_queue_manager import ConversationTaskStoppedException, PublishFrom, AppQueueManager
+from core.app.apps.chat.app_config_manager import ChatAppConfigManager
+from core.app.apps.chat.app_runner import ChatAppRunner
+from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
+from core.app.entities.app_invoke_entities import InvokeFrom, ChatAppGenerateEntity
+from core.file.message_file_parser import MessageFileParser
+from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from extensions.ext_database import db
+from models.account import Account
+from models.model import App, EndUser
+
+logger = logging.getLogger(__name__)
+
+
+class ChatAppGenerator(MessageBasedAppGenerator):
+    def generate(self, app_model: App,
+                 user: Union[Account,
EndUser], + args: Any, + invoke_from: InvokeFrom, + stream: bool = True) \ + -> Union[dict, Generator]: + """ + Generate App response. + + :param app_model: App + :param user: account or end user + :param args: request args + :param invoke_from: invoke from source + :param stream: is stream + """ + if not args.get('query'): + raise ValueError('query is required') + + query = args['query'] + if not isinstance(query, str): + raise ValueError('query must be a string') + + query = query.replace('\x00', '') + inputs = args['inputs'] + + extras = { + "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else True + } + + # get conversation + conversation = None + if args.get('conversation_id'): + conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + + # get app model config + app_model_config = self._get_app_model_config( + app_model=app_model, + conversation=conversation + ) + + # validate override model config + override_model_config_dict = None + if args.get('model_config'): + if invoke_from != InvokeFrom.DEBUGGER: + raise ValueError('Only in App debug mode can override model config') + + # validate config + override_model_config_dict = ChatAppConfigManager.config_validate( + tenant_id=app_model.tenant_id, + config=args.get('model_config') + ) + + # parse files + files = args['files'] if 'files' in args and args['files'] else [] + message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) + if file_upload_entity: + file_objs = message_file_parser.validate_and_transform_files_arg( + files, + file_upload_entity, + user + ) + else: + file_objs = [] + + # convert to app config + app_config = ChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=conversation, + override_config_dict=override_model_config_dict + ) + + # init application generate entity + application_generate_entity = ChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_config=ModelConfigConverter.convert(app_config), + conversation_id=conversation.id if conversation else None, + inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), + query=query, + files=file_objs, + user_id=user.id, + stream=stream, + invoke_from=invoke_from, + extras=extras + ) + + # init generate records + ( + conversation, + message + ) = self._init_generate_records(application_generate_entity, conversation) + + # init queue manager + queue_manager = AppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id + ) + + # new thread + worker_thread = threading.Thread(target=self._generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'application_generate_entity': application_generate_entity, + 'queue_manager': queue_manager, + 'conversation_id': conversation.id, + 'message_id': message.id, + }) + + worker_thread.start() + + # return response or stream generator + return self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + stream=stream + ) + + def _generate_worker(self, flask_app: Flask, + 
application_generate_entity: ChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation_id: str, + message_id: str) -> None: + """ + Generate worker in a new thread. + :param flask_app: Flask app + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param conversation_id: conversation ID + :param message_id: message ID + :return: + """ + with flask_app.app_context(): + try: + # get conversation and message + conversation = self._get_conversation(conversation_id) + message = self._get_message(message_id) + + # chatbot app + runner = ChatAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + except ConversationTaskStoppedException: + pass + except InvokeAuthorizationError: + queue_manager.publish_error( + InvokeAuthorizationError('Incorrect API key provided'), + PublishFrom.APPLICATION_MANAGER + ) + except ValidationError as e: + logger.exception("Validation Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except (ValueError, InvokeError) as e: + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except Exception as e: + logger.exception("Unknown Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + finally: + db.session.remove() diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 1b256f11c4..57aca9d3e6 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -5,7 +5,7 @@ from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.apps.chat.app_config_manager import ChatAppConfig from core.app.entities.app_invoke_entities import ( - EasyUIBasedAppGenerateEntity, + ChatAppGenerateEntity, ) from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.memory.token_buffer_memory import TokenBufferMemory @@ -23,7 +23,7 @@ class ChatAppRunner(AppRunner): Chat Application Runner """ - def run(self, application_generate_entity: EasyUIBasedAppGenerateEntity, + def run(self, application_generate_entity: ChatAppGenerateEntity, queue_manager: AppQueueManager, conversation: Conversation, message: Message) -> None: diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py index 6bdb7cc4b3..77a1443037 100644 --- a/api/core/app/apps/completion/app_config_manager.py +++ b/api/core/app/apps/completion/app_config_manager.py @@ -10,7 +10,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager -from models.model import App, AppMode, AppModelConfig +from models.model import App, AppMode, AppModelConfig, Conversation class CompletionAppConfig(EasyUIBasedAppConfig): @@ -22,19 +22,26 @@ class CompletionAppConfig(EasyUIBasedAppConfig): class CompletionAppConfigManager(BaseAppConfigManager): @classmethod - def config_convert(cls, app_model: App, - config_from: EasyUIBasedAppModelConfigFrom, + def get_app_config(cls, app_model: App, app_model_config: AppModelConfig, - config_dict: Optional[dict] = None) -> 
CompletionAppConfig:
+                       override_config_dict: Optional[dict] = None) -> CompletionAppConfig:
         """
         Convert app model config to completion app config
         :param app_model: app model
-        :param config_from: app model config from
         :param app_model_config: app model config
-        :param config_dict: app model config dict
+        :param override_config_dict: app model config dict
         :return:
         """
-        config_dict = cls.convert_to_config_dict(config_from, app_model_config, config_dict)
+        if override_config_dict:
+            config_from = EasyUIBasedAppModelConfigFrom.ARGS
+        else:
+            config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
+
+        if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
+            app_model_config_dict = app_model_config.to_dict()
+            config_dict = app_model_config_dict.copy()
+        else:
+            config_dict = override_config_dict
 
         app_config = CompletionAppConfig(
             tenant_id=app_model.tenant_id,
diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py
new file mode 100644
index 0000000000..d258a3bd9d
--- /dev/null
+++ b/api/core/app/apps/completion/app_generator.py
@@ -0,0 +1,292 @@
+import json
+import logging
+import threading
+import uuid
+from typing import Union, Any, Generator
+
+from flask import current_app, Flask
+from pydantic import ValidationError
+
+from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
+from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
+from core.app.app_queue_manager import ConversationTaskStoppedException, PublishFrom, AppQueueManager
+from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
+from core.app.apps.completion.app_runner import CompletionAppRunner
+from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
+from core.app.entities.app_invoke_entities import InvokeFrom, CompletionAppGenerateEntity
+from core.file.message_file_parser import MessageFileParser
+from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from extensions.ext_database import db
+from models.account import Account
+from models.model import App, EndUser, Message
+from services.errors.app import MoreLikeThisDisabledError
+from services.errors.message import MessageNotExistsError
+
+logger = logging.getLogger(__name__)
+
+
+class CompletionAppGenerator(MessageBasedAppGenerator):
+    def generate(self, app_model: App,
+                 user: Union[Account, EndUser],
+                 args: Any,
+                 invoke_from: InvokeFrom,
+                 stream: bool = True) \
+        -> Union[dict, Generator]:
+        """
+        Generate App response.
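
# The args dict drives everything below; an illustrative example built from
# the keys this method actually reads ('query', 'inputs', 'files', and the
# debugger-only 'model_config' override). The override's exact schema is an
# assumption here, and all values are made up:
args = {
    'query': 'Summarize the quarterly report.',
    'inputs': {'tone': 'formal'},
    'files': [],
    # honoured only when invoke_from == InvokeFrom.DEBUGGER:
    'model_config': {
        'model': {
            'provider': 'openai',
            'name': 'gpt-4',
            'completion_params': {'temperature': 0.7},
        },
    },
}
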
+ + :param app_model: App + :param user: account or end user + :param args: request args + :param invoke_from: invoke from source + :param stream: is stream + """ + query = args['query'] + if not isinstance(query, str): + raise ValueError('query must be a string') + + query = query.replace('\x00', '') + inputs = args['inputs'] + + extras = {} + + # get conversation + conversation = None + + # get app model config + app_model_config = self._get_app_model_config( + app_model=app_model, + conversation=conversation + ) + + # validate override model config + override_model_config_dict = None + if args.get('model_config'): + if invoke_from != InvokeFrom.DEBUGGER: + raise ValueError('Only in App debug mode can override model config') + + # validate config + override_model_config_dict = CompletionAppConfigManager.config_validate( + tenant_id=app_model.tenant_id, + config=args.get('model_config') + ) + + # parse files + files = args['files'] if 'files' in args and args['files'] else [] + message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) + if file_upload_entity: + file_objs = message_file_parser.validate_and_transform_files_arg( + files, + file_upload_entity, + user + ) + else: + file_objs = [] + + # convert to app config + app_config = CompletionAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + override_config_dict=override_model_config_dict + ) + + # init application generate entity + application_generate_entity = CompletionAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_config=ModelConfigConverter.convert(app_config), + inputs=self._get_cleaned_inputs(inputs, app_config), + query=query, + files=file_objs, + user_id=user.id, + stream=stream, + invoke_from=invoke_from, + extras=extras + ) + + # init generate records + ( + conversation, + message + ) = self._init_generate_records(application_generate_entity) + + # init queue manager + queue_manager = AppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id + ) + + # new thread + worker_thread = threading.Thread(target=self._generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'application_generate_entity': application_generate_entity, + 'queue_manager': queue_manager, + 'message_id': message.id, + }) + + worker_thread.start() + + # return response or stream generator + return self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + stream=stream + ) + + def _generate_worker(self, flask_app: Flask, + application_generate_entity: CompletionAppGenerateEntity, + queue_manager: AppQueueManager, + message_id: str) -> None: + """ + Generate worker in a new thread. 
+ :param flask_app: Flask app + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param conversation_id: conversation ID + :param message_id: message ID + :return: + """ + with flask_app.app_context(): + try: + # get message + message = self._get_message(message_id) + + # chatbot app + runner = CompletionAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + message=message + ) + except ConversationTaskStoppedException: + pass + except InvokeAuthorizationError: + queue_manager.publish_error( + InvokeAuthorizationError('Incorrect API key provided'), + PublishFrom.APPLICATION_MANAGER + ) + except ValidationError as e: + logger.exception("Validation Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except (ValueError, InvokeError) as e: + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except Exception as e: + logger.exception("Unknown Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + finally: + db.session.remove() + + def generate_more_like_this(self, app_model: App, + message_id: str, + user: Union[Account, EndUser], + invoke_from: InvokeFrom, + stream: bool = True) \ + -> Union[dict, Generator]: + """ + Generate App response. + + :param app_model: App + :param message_id: message ID + :param user: account or end user + :param invoke_from: invoke from source + :param stream: is stream + """ + message = db.session.query(Message).filter( + Message.id == message_id, + Message.app_id == app_model.id, + Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), + Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Message.from_account_id == (user.id if isinstance(user, Account) else None), + ).first() + + if not message: + raise MessageNotExistsError() + + current_app_model_config = app_model.app_model_config + more_like_this = current_app_model_config.more_like_this_dict + + if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False: + raise MoreLikeThisDisabledError() + + app_model_config = message.app_model_config + override_model_config_dict = app_model_config.to_dict() + model_dict = override_model_config_dict['model'] + completion_params = model_dict.get('completion_params') + completion_params['temperature'] = 0.9 + model_dict['completion_params'] = completion_params + override_model_config_dict['model'] = model_dict + + # parse files + message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) + if file_upload_entity: + file_objs = message_file_parser.validate_and_transform_files_arg( + message.files, + file_upload_entity, + user + ) + else: + file_objs = [] + + # convert to app config + app_config = CompletionAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + override_config_dict=override_model_config_dict + ) + + # init application generate entity + application_generate_entity = CompletionAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_config=ModelConfigConverter.convert(app_config), + inputs=message.inputs, + query=message.query, + files=file_objs, + user_id=user.id, + stream=stream, + invoke_from=invoke_from, + extras={} + ) + + # init generate records + ( + 
conversation, + message + ) = self._init_generate_records(application_generate_entity) + + # init queue manager + queue_manager = AppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id + ) + + # new thread + worker_thread = threading.Thread(target=self._generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'application_generate_entity': application_generate_entity, + 'queue_manager': queue_manager, + 'message_id': message.id, + }) + + worker_thread.start() + + # return response or stream generator + return self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + stream=stream + ) diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index d60e14aaeb..c5b8ca6c9a 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -5,7 +5,7 @@ from core.app.app_queue_manager import AppQueueManager from core.app.apps.base_app_runner import AppRunner from core.app.apps.completion.app_config_manager import CompletionAppConfig from core.app.entities.app_invoke_entities import ( - EasyUIBasedAppGenerateEntity, + CompletionAppGenerateEntity, ) from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.model_manager import ModelInstance @@ -22,7 +22,7 @@ class CompletionAppRunner(AppRunner): Completion Application Runner """ - def run(self, application_generate_entity: EasyUIBasedAppGenerateEntity, + def run(self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message) -> None: """ diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py new file mode 100644 index 0000000000..783c6c6ee5 --- /dev/null +++ b/api/core/app/apps/message_based_app_generator.py @@ -0,0 +1,251 @@ +import json +import logging +from typing import Union, Generator, Optional + +from sqlalchemy import and_ + +from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom +from core.app.app_queue_manager import ConversationTaskStoppedException, AppQueueManager +from core.app.apps.base_app_generator import BaseAppGenerator +from core.app.entities.app_invoke_entities import InvokeFrom, ChatAppGenerateEntity, AppGenerateEntity, \ + CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity +from core.app.generate_task_pipeline import GenerateTaskPipeline +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from extensions.ext_database import db +from models.account import Account +from models.model import Conversation, Message, AppMode, MessageFile, App, EndUser, AppModelConfig +from services.errors.app_model_config import AppModelConfigBrokenError +from services.errors.conversation import ConversationNotExistsError, ConversationCompletedError + +logger = logging.getLogger(__name__) + + +class MessageBasedAppGenerator(BaseAppGenerator): + + def _handle_response(self, application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity + ], + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + stream: bool = False) -> Union[dict, 
Generator]:
+        """
+        Handle response.
+        :param application_generate_entity: application generate entity
+        :param queue_manager: queue manager
+        :param conversation: conversation
+        :param message: message
+        :param stream: is stream
+        :return:
+        """
+        # init generate task pipeline
+        generate_task_pipeline = GenerateTaskPipeline(
+            application_generate_entity=application_generate_entity,
+            queue_manager=queue_manager,
+            conversation=conversation,
+            message=message
+        )
+
+        try:
+            return generate_task_pipeline.process(stream=stream)
+        except ValueError as e:
+            if e.args[0] == "I/O operation on closed file.":  # ignore this error
+                raise ConversationTaskStoppedException()
+            else:
+                logger.exception(e)
+                raise e
+        finally:
+            db.session.remove()
+
+    def _get_conversation_by_user(self, app_model: App, conversation_id: str,
+                                  user: Union[Account, EndUser]) -> Conversation:
+        conversation_filter = [
+            Conversation.id == conversation_id,
+            Conversation.app_id == app_model.id,
+            Conversation.status == 'normal'
+        ]
+
+        if isinstance(user, Account):
+            conversation_filter.append(Conversation.from_account_id == user.id)
+        else:
+            conversation_filter.append(Conversation.from_end_user_id == user.id if user else None)
+
+        conversation = db.session.query(Conversation).filter(and_(*conversation_filter)).first()
+
+        if not conversation:
+            raise ConversationNotExistsError()
+
+        if conversation.status != 'normal':
+            raise ConversationCompletedError()
+
+        return conversation
+
+    def _get_app_model_config(self, app_model: App,
+                              conversation: Optional[Conversation] = None) \
+            -> AppModelConfig:
+        if conversation:
+            app_model_config = db.session.query(AppModelConfig).filter(
+                AppModelConfig.id == conversation.app_model_config_id,
+                AppModelConfig.app_id == app_model.id
+            ).first()
+
+            if not app_model_config:
+                raise AppModelConfigBrokenError()
+        else:
+            if app_model.app_model_config_id is None:
+                raise AppModelConfigBrokenError()
+
+            app_model_config = app_model.app_model_config
+
+            if not app_model_config:
+                raise AppModelConfigBrokenError()
+
+        return app_model_config
+
+    def _init_generate_records(self,
+                               application_generate_entity: Union[
+                                   ChatAppGenerateEntity,
+                                   CompletionAppGenerateEntity,
+                                   AgentChatAppGenerateEntity
+                               ],
+                               conversation: Optional[Conversation] = None) \
+            -> tuple[Conversation, Message]:
+        """
+        Initialize generate records
+        :param application_generate_entity: application generate entity
+        :return:
+        """
+        app_config = application_generate_entity.app_config
+
+        # get from source
+        end_user_id = None
+        account_id = None
+        if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
+            from_source = 'api'
+            end_user_id = application_generate_entity.user_id
+        else:
+            from_source = 'console'
+            account_id = application_generate_entity.user_id
+
+        override_model_configs = None
+        if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS \
+                and app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]:
+            override_model_configs = app_config.app_model_config_dict
+
+        # get conversation introduction
+        introduction = self._get_conversation_introduction(application_generate_entity)
+
+        if not conversation:
+            conversation = Conversation(
+                app_id=app_config.app_id,
+                app_model_config_id=app_config.app_model_config_id,
+                model_provider=application_generate_entity.model_config.provider,
+                model_id=application_generate_entity.model_config.model,
+                override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
+                mode=app_config.app_mode.value,
+                name='New conversation',
+                inputs=application_generate_entity.inputs,
+                introduction=introduction,
+                system_instruction="",
+                system_instruction_tokens=0,
+                status='normal',
+                from_source=from_source,
+                from_end_user_id=end_user_id,
+                from_account_id=account_id,
+            )
+
+            db.session.add(conversation)
+            db.session.commit()
+
+        message = Message(
+            app_id=app_config.app_id,
+            model_provider=application_generate_entity.model_config.provider,
+            model_id=application_generate_entity.model_config.model,
+            override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
+            conversation_id=conversation.id,
+            inputs=application_generate_entity.inputs,
+            query=application_generate_entity.query or "",
+            message="",
+            message_tokens=0,
+            message_unit_price=0,
+            message_price_unit=0,
+            answer="",
+            answer_tokens=0,
+            answer_unit_price=0,
+            answer_price_unit=0,
+            provider_response_latency=0,
+            total_price=0,
+            currency='USD',
+            from_source=from_source,
+            from_end_user_id=end_user_id,
+            from_account_id=account_id
+        )
+
+        db.session.add(message)
+        db.session.commit()
+
+        for file in application_generate_entity.files:
+            message_file = MessageFile(
+                message_id=message.id,
+                type=file.type.value,
+                transfer_method=file.transfer_method.value,
+                belongs_to='user',
+                url=file.url,
+                upload_file_id=file.upload_file_id,
+                created_by_role=('account' if account_id else 'end_user'),
+                created_by=account_id or end_user_id,
+            )
+            db.session.add(message_file)
+            db.session.commit()
+
+        return conversation, message
+
+    def _get_conversation_introduction(self, application_generate_entity: AppGenerateEntity) -> str:
+        """
+        Get conversation introduction
+        :param application_generate_entity: application generate entity
+        :return: conversation introduction
+        """
+        app_config = application_generate_entity.app_config
+        introduction = app_config.additional_features.opening_statement
+
+        if introduction:
+            try:
+                inputs = application_generate_entity.inputs
+                prompt_template = PromptTemplateParser(template=introduction)
+                prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+                introduction = prompt_template.format(prompt_inputs)
+            except KeyError:
+                pass
+
+        return introduction
+
+    def _get_conversation(self, conversation_id: str) -> Conversation:
+        """
+        Get conversation by conversation id
+        :param conversation_id: conversation id
+        :return: conversation
+        """
+        conversation = (
+            db.session.query(Conversation)
+            .filter(Conversation.id == conversation_id)
+            .first()
+        )
+
+        return conversation
+
+    def _get_message(self, message_id: str) -> Message:
+        """
+        Get message by message id
+        :param message_id: message id
+        :return: message
+        """
+        message = (
+            db.session.query(Message)
+            .filter(Message.id == message_id)
+            .first()
+        )
+
+        return message
diff --git a/api/core/app/apps/workflow/app_config_manager.py b/api/core/app/apps/workflow/app_config_manager.py
index 194339a23b..91bab1b218 100644
--- a/api/core/app/apps/workflow/app_config_manager.py
+++ b/api/core/app/apps/workflow/app_config_manager.py
@@ -17,7 +17,7 @@ class WorkflowAppConfig(WorkflowUIBasedAppConfig):
 
 class WorkflowAppConfigManager(BaseAppConfigManager):
     @classmethod
-    def config_convert(cls, app_model: App, workflow: Workflow) -> WorkflowAppConfig:
+    def get_app_config(cls, app_model: App, workflow: Workflow) -> WorkflowAppConfig:
         features_dict = workflow.features_dict
 
         app_config = WorkflowAppConfig(
diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py
index fae9044fc3..9097345674 100644
--- a/api/core/app/entities/app_invoke_entities.py
+++ b/api/core/app/entities/app_invoke_entities.py
@@ -3,7 +3,7 @@ from typing import Any, Optional
 
 from pydantic import BaseModel
 
-from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
+from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig, AppConfig
 from core.entities.provider_configuration import ProviderModelBundle
 from core.file.file_obj import FileObj
 from core.model_runtime.entities.model_entities import AIModelEntity
@@ -49,9 +49,9 @@ class InvokeFrom(Enum):
         return 'dev'
 
 
-class EasyUIBasedModelConfigEntity(BaseModel):
+class ModelConfigWithCredentialsEntity(BaseModel):
     """
-    Model Config Entity.
+    Model Config With Credentials Entity.
     """
     provider: str
     model: str
@@ -63,21 +63,19 @@
     stop: list[str] = []
 
 
-class EasyUIBasedAppGenerateEntity(BaseModel):
+class AppGenerateEntity(BaseModel):
     """
-    EasyUI Based Application Generate Entity.
+    App Generate Entity.
     """
     task_id: str
 
     # app config
-    app_config: EasyUIBasedAppConfig
-    model_config: EasyUIBasedModelConfigEntity
+    app_config: AppConfig
 
-    conversation_id: Optional[str] = None
     inputs: dict[str, str]
-    query: Optional[str] = None
     files: list[FileObj] = []
     user_id: str
+
     # extras
     stream: bool
     invoke_from: InvokeFrom
@@ -86,26 +84,52 @@
     extras: dict[str, Any] = {}
 
 
-class WorkflowUIBasedAppGenerateEntity(BaseModel):
+class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
     """
-    Workflow UI Based Application Generate Entity.
+    Chat Application Generate Entity.
     """
-    task_id: str
+    # app config
+    app_config: EasyUIBasedAppConfig
+    model_config: ModelConfigWithCredentialsEntity
 
+    query: Optional[str] = None
+
+
+class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
+    """
+    Chat Application Generate Entity.
+    """
+    conversation_id: Optional[str] = None
+
+
+class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
+    """
+    Completion Application Generate Entity.
+    """
+    pass
+
+
+class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
+    """
+    Agent Chat Application Generate Entity.
+    """
+    conversation_id: Optional[str] = None
+
+
+class AdvancedChatAppGenerateEntity(AppGenerateEntity):
+    """
+    Advanced Chat Application Generate Entity.
+    """
     # app config
     app_config: WorkflowUIBasedAppConfig
 
-    inputs: dict[str, str]
-    files: list[FileObj] = []
-    user_id: str
-
-    # extras
-    stream: bool
-    invoke_from: InvokeFrom
-
-    # extra parameters
-    extras: dict[str, Any] = {}
-
-
-class AdvancedChatAppGenerateEntity(WorkflowUIBasedAppGenerateEntity):
     conversation_id: Optional[str] = None
-    query: str
+    query: Optional[str] = None
+
+
+class WorkflowUIBasedAppGenerateEntity(AppGenerateEntity):
+    """
+    Workflow UI Based Application Generate Entity.
+    """
+    # app config
+    app_config: WorkflowUIBasedAppConfig
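For orientation, a minimal sketch of how the reworked hierarchy above composes. The class and field names come from the diff itself; the concrete values and the `app_config`/`model_config` objects are hypothetical placeholders, not code from the patch:

```python
# Sketch only: illustrates the entity hierarchy introduced above.
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom

# ChatAppGenerateEntity -> EasyUIBasedAppGenerateEntity -> AppGenerateEntity,
# so it inherits task_id/inputs/files/user_id/stream/invoke_from from the base,
# gets app_config/model_config/query from the EasyUI layer, and adds conversation_id.
entity = ChatAppGenerateEntity(
    task_id="00000000-0000-0000-0000-000000000000",  # hypothetical
    app_config=app_config,        # an EasyUIBasedAppConfig built elsewhere
    model_config=model_config,    # a ModelConfigWithCredentialsEntity
    conversation_id=None,         # None means a new conversation will be created
    inputs={"name": "example"},
    query="Hello",
    user_id="hypothetical-user-id",
    stream=True,
    invoke_from=InvokeFrom.DEBUGGER,
)
```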
diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py
index ec316248a2..7d555328db 100644
--- a/api/core/app/features/hosting_moderation/hosting_moderation.py
+++ b/api/core/app/features/hosting_moderation/hosting_moderation.py
@@ -1,6 +1,6 @@
 import logging
 
-from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity
+from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, EasyUIBasedAppGenerateEntity
 from core.helper import moderation
 from core.model_runtime.entities.message_entities import PromptMessage
diff --git a/api/core/app/generate_task_pipeline.py b/api/core/app/generate_task_pipeline.py
index 359369ef59..926b0e128c 100644
--- a/api/core/app/generate_task_pipeline.py
+++ b/api/core/app/generate_task_pipeline.py
@@ -7,7 +7,8 @@ from typing import Optional, Union, cast
 
 from pydantic import BaseModel
 
 from core.app.app_queue_manager import AppQueueManager, PublishFrom
-from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity, InvokeFrom
+from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom, CompletionAppGenerateEntity, \
+    AgentChatAppGenerateEntity
 from core.app.entities.queue_entities import (
     AnnotationReplyEvent,
     QueueAgentMessageEvent,
@@ -39,7 +40,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.tools.tool_file_manager import ToolFileManager
 from events.message_event import message_was_created
 from extensions.ext_database import db
-from models.model import Conversation, Message, MessageAgentThought, MessageFile
+from models.model import Conversation, Message, MessageAgentThought, MessageFile, AppMode
 from services.annotation_service import AppAnnotationService
 
 logger = logging.getLogger(__name__)
@@ -58,7 +59,11 @@ class GenerateTaskPipeline:
     GenerateTaskPipeline is a class that generate stream output and state management for Application.
     """
 
-    def __init__(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
+    def __init__(self, application_generate_entity: Union[
+        ChatAppGenerateEntity,
+        CompletionAppGenerateEntity,
+        AgentChatAppGenerateEntity
+    ],
                  queue_manager: AppQueueManager,
                  conversation: Conversation,
                  message: Message) -> None:
@@ -433,6 +438,7 @@ class GenerateTaskPipeline:
         self._message.answer_price_unit = usage.completion_price_unit
         self._message.provider_response_latency = time.perf_counter() - self._start_at
         self._message.total_price = usage.total_price
+        self._message.currency = usage.currency
 
         db.session.commit()
 
@@ -440,7 +446,11 @@
             self._message,
             application_generate_entity=self._application_generate_entity,
             conversation=self._conversation,
-            is_first_message=self._application_generate_entity.conversation_id is None,
+            is_first_message=self._application_generate_entity.app_config.app_mode in [
+                AppMode.AGENT_CHAT,
+                AppMode.CHAT,
+                AppMode.ADVANCED_CHAT
+            ] and self._application_generate_entity.conversation_id is None,
             extras=self._application_generate_entity.extras
         )
diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py
index bff9b9cf1f..20feae8554 100644
--- a/api/core/helper/moderation.py
+++ b/api/core/helper/moderation.py
@@ -1,7 +1,7 @@
 import logging
 import random
 
-from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.model_runtime.errors.invoke import InvokeBadRequestError
 from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel
 from extensions.ext_hosting_provider import hosting_configuration
@@ -10,7 +10,7 @@ from models.provider import ProviderType
 
 logger = logging.getLogger(__name__)
 
-def check_moderation(model_config: EasyUIBasedModelConfigEntity, text: str) -> bool:
+def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) -> bool:
     moderation_config = hosting_configuration.moderation_config
     if (moderation_config and moderation_config.enabled is True
             and 'openai' in hosting_configuration.provider_map
diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py
index cdd03b85f1..48b0d8ba02 100644
--- a/api/core/prompt/advanced_prompt_transform.py
+++ b/api/core/prompt/advanced_prompt_transform.py
@@ -1,7 +1,7 @@
 from typing import Optional
 
 from core.app.app_config.entities import AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity
-from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.file.file_obj import FileObj
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import (
@@ -28,7 +28,7 @@ class AdvancedPromptTransform(PromptTransform):
                     files: list[FileObj],
                     context: Optional[str],
                     memory: Optional[TokenBufferMemory],
-                    model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]:
+                    model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
         prompt_messages = []
 
         model_mode = ModelMode.value_of(model_config.mode)
@@ -62,7 +62,7 @@ class AdvancedPromptTransform(PromptTransform):
                     files: list[FileObj],
                     context: Optional[str],
                     memory: Optional[TokenBufferMemory],
-                    model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]:
+                    model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
         """
         Get completion model prompt messages.
         """
@@ -110,7 +110,7 @@ class AdvancedPromptTransform(PromptTransform):
                     files: list[FileObj],
                     context: Optional[str],
                     memory: Optional[TokenBufferMemory],
-                    model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]:
+                    model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
         """
         Get chat model prompt messages.
         """
@@ -199,7 +199,7 @@ class AdvancedPromptTransform(PromptTransform):
                     role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity,
                     prompt_template: PromptTemplateParser,
                     prompt_inputs: dict,
-                    model_config: EasyUIBasedModelConfigEntity) -> dict:
+                    model_config: ModelConfigWithCredentialsEntity) -> dict:
         if '#histories#' in prompt_template.variable_keys:
             if memory:
                 inputs = {'#histories#': '', **prompt_inputs}
diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py
index 7fe8128a49..02e91d9112 100644
--- a/api/core/prompt/prompt_transform.py
+++ b/api/core/prompt/prompt_transform.py
@@ -1,6 +1,6 @@
 from typing import Optional, cast
 
-from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import PromptMessage
 from core.model_runtime.entities.model_entities import ModelPropertyKey
@@ -10,14 +10,14 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
 class PromptTransform:
     def _append_chat_histories(self, memory: TokenBufferMemory,
                                prompt_messages: list[PromptMessage],
-                               model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]:
+                               model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
         rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
         histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
         prompt_messages.extend(histories)
 
         return prompt_messages
 
-    def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: EasyUIBasedModelConfigEntity) -> int:
+    def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity) -> int:
         rest_tokens = 2000
 
         model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py
index faf1f888e2..ca0efb200c 100644
--- a/api/core/prompt/simple_prompt_transform.py
+++ b/api/core/prompt/simple_prompt_transform.py
@@ -4,7 +4,7 @@ import os
 from typing import Optional
 
 from core.app.app_config.entities import PromptTemplateEntity
-from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.file.file_obj import FileObj
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import (
@@ -52,7 +52,7 @@ class SimplePromptTransform(PromptTransform):
                     files: list[FileObj],
                     context: Optional[str],
                     memory: Optional[TokenBufferMemory],
-                    model_config: EasyUIBasedModelConfigEntity) -> \
+                    model_config: ModelConfigWithCredentialsEntity) -> \
             tuple[list[PromptMessage], Optional[list[str]]]:
         model_mode = ModelMode.value_of(model_config.mode)
         if model_mode == ModelMode.CHAT:
@@ -81,7 +81,7 @@ class SimplePromptTransform(PromptTransform):
         return prompt_messages, stops
 
     def get_prompt_str_and_rules(self,
                                  app_mode: AppMode,
-                                 model_config: EasyUIBasedModelConfigEntity,
+                                 model_config: ModelConfigWithCredentialsEntity,
                                  pre_prompt: str,
                                  inputs: dict,
                                  query: Optional[str] = None,
@@ -162,7 +162,7 @@ class SimplePromptTransform(PromptTransform):
                     context: Optional[str],
                     files: list[FileObj],
                     memory: Optional[TokenBufferMemory],
-                    model_config: EasyUIBasedModelConfigEntity) \
+                    model_config: ModelConfigWithCredentialsEntity) \
             -> tuple[list[PromptMessage], Optional[list[str]]]:
         prompt_messages = []
 
@@ -200,7 +200,7 @@ class SimplePromptTransform(PromptTransform):
                     context: Optional[str],
                     files: list[FileObj],
                     memory: Optional[TokenBufferMemory],
-                    model_config: EasyUIBasedModelConfigEntity) \
+                    model_config: ModelConfigWithCredentialsEntity) \
             -> tuple[list[PromptMessage], Optional[list[str]]]:
         # get prompt
         prompt, prompt_rules = self.get_prompt_str_and_rules(
diff --git a/api/core/rag/retrieval/agent/llm_chain.py b/api/core/rag/retrieval/agent/llm_chain.py
index 9b115bc696..f2c5d4ca33 100644
--- a/api/core/rag/retrieval/agent/llm_chain.py
+++ b/api/core/rag/retrieval/agent/llm_chain.py
@@ -5,14 +5,14 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.schema import Generation, LLMResult
 from langchain.schema.language_model import BaseLanguageModel
 
-from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.message_entities import lc_messages_to_prompt_messages
 from core.model_manager import ModelInstance
 from core.rag.retrieval.agent.fake_llm import FakeLLM
 
 
 class LLMChain(LCLLMChain):
-    model_config: EasyUIBasedModelConfigEntity
+    model_config: ModelConfigWithCredentialsEntity
     """The language model instance to use."""
     llm: BaseLanguageModel = FakeLLM(response="")
     parameters: dict[str, Any] = {}
diff --git a/api/core/rag/retrieval/agent/multi_dataset_router_agent.py b/api/core/rag/retrieval/agent/multi_dataset_router_agent.py
index 84e2b0228f..be24731d46 100644
--- a/api/core/rag/retrieval/agent/multi_dataset_router_agent.py
+++ b/api/core/rag/retrieval/agent/multi_dataset_router_agent.py
@@ -10,7 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, AIMessage, SystemMessage
 from langchain.tools import BaseTool
 from pydantic import root_validator
 
-from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.message_entities import lc_messages_to_prompt_messages
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.message_entities import PromptMessageTool
@@ -21,7 +21,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
     """
     An Multi Dataset Retrieve Agent driven by Router.
     """
-    model_config: EasyUIBasedModelConfigEntity
+    model_config: ModelConfigWithCredentialsEntity
 
     class Config:
         """Configuration for this pydantic object."""
@@ -156,7 +156,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
     @classmethod
     def from_llm_and_tools(
             cls,
-            model_config: EasyUIBasedModelConfigEntity,
+            model_config: ModelConfigWithCredentialsEntity,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
diff --git a/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py b/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py
index 700bf0c293..7035ec8e2f 100644
--- a/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py
+++ b/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py
@@ -12,7 +12,7 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy
 from langchain.schema import AgentAction, AgentFinish, OutputParserException
 from langchain.tools import BaseTool
 
-from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.rag.retrieval.agent.llm_chain import LLMChain
 
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
@@ -206,7 +206,7 @@ Thought: {agent_scratchpad}
     @classmethod
     def from_llm_and_tools(
             cls,
-            model_config: EasyUIBasedModelConfigEntity,
+            model_config: ModelConfigWithCredentialsEntity,
             tools: Sequence[BaseTool],
             callback_manager: Optional[BaseCallbackManager] = None,
             output_parser: Optional[AgentOutputParser] = None,
diff --git a/api/core/rag/retrieval/agent_based_dataset_executor.py b/api/core/rag/retrieval/agent_based_dataset_executor.py
index 749e603c5c..cb475bcffb 100644
--- a/api/core/rag/retrieval/agent_based_dataset_executor.py
+++ b/api/core/rag/retrieval/agent_based_dataset_executor.py
@@ -7,7 +7,7 @@ from langchain.callbacks.manager import Callbacks
 from langchain.tools import BaseTool
 from pydantic import BaseModel, Extra
 
-from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.agent_entities import PlanningStrategy
 from core.entities.message_entities import prompt_messages_to_lc_messages
 from core.helper import moderation
@@ -22,9 +22,9 @@ from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetr
 
 class AgentConfiguration(BaseModel):
     strategy: PlanningStrategy
-    model_config: EasyUIBasedModelConfigEntity
+    model_config: ModelConfigWithCredentialsEntity
     tools: list[BaseTool]
-    summary_model_config: Optional[EasyUIBasedModelConfigEntity] = None
+    summary_model_config: Optional[ModelConfigWithCredentialsEntity] = None
     memory: Optional[TokenBufferMemory] = None
     callbacks: Callbacks = None
     max_iterations: int = 6
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 37581f1e92..395f2eb165 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -3,7 +3,7 @@ from typing import Optional, cast
 from langchain.tools import BaseTool
 
 from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
-from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity, InvokeFrom
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity, InvokeFrom
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.agent_entities import PlanningStrategy
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -18,7 +18,7 @@ from models.dataset import Dataset
 
 class DatasetRetrieval:
     def retrieve(self, tenant_id: str,
-                 model_config: EasyUIBasedModelConfigEntity,
+                 model_config: ModelConfigWithCredentialsEntity,
                  config: DatasetEntity,
                  query: str,
                  invoke_from: InvokeFrom,
diff --git a/api/events/event_handlers/deduct_quota_when_messaeg_created.py b/api/events/event_handlers/deduct_quota_when_messaeg_created.py
index 49eea603dc..77d1ab0822 100644
--- a/api/events/event_handlers/deduct_quota_when_messaeg_created.py
+++ b/api/events/event_handlers/deduct_quota_when_messaeg_created.py
@@ -1,4 +1,4 @@
-from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity
+from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
 from core.entities.provider_entities import QuotaUnit
 from events.message_event import message_was_created
 from extensions.ext_database import db
@@ -8,7 +8,7 @@ from models.provider import Provider, ProviderType
 @message_was_created.connect
 def handle(sender, **kwargs):
     message = sender
-    application_generate_entity: EasyUIBasedAppGenerateEntity = kwargs.get('application_generate_entity')
+    application_generate_entity: ChatAppGenerateEntity = kwargs.get('application_generate_entity')
 
     model_config = application_generate_entity.model_config
     provider_model_bundle = model_config.provider_model_bundle
diff --git a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py
index d49e560a67..eca773f3b3 100644
--- a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py
+++ b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py
@@ -1,6 +1,6 @@
 from datetime import datetime
 
-from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity
+from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
 from events.message_event import message_was_created
 from extensions.ext_database import db
 from models.provider import Provider
@@ -9,7 +9,7 @@ from models.provider import Provider
 @message_was_created.connect
 def handle(sender, **kwargs):
     message = sender
-    application_generate_entity: EasyUIBasedAppGenerateEntity = kwargs.get('application_generate_entity')
+    application_generate_entity: ChatAppGenerateEntity = kwargs.get('application_generate_entity')
 
     db.session.query(Provider).filter(
         Provider.tenant_id == application_generate_entity.app_config.tenant_id,
diff --git a/api/services/completion_service.py b/api/services/completion_service.py
index 453194feb1..4e3c4e19f6 100644
--- a/api/services/completion_service.py
+++ b/api/services/completion_service.py
@@ -1,180 +1,71 @@
-import json
 from collections.abc import Generator
 from typing import Any, Union
 
-from sqlalchemy import and_
-
-from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
-from core.app.app_manager import EasyUIBasedAppManager
+from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
+from core.app.apps.chat.app_generator import ChatAppGenerator
+from core.app.apps.completion.app_generator import CompletionAppGenerator
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.file.message_file_parser import MessageFileParser
-from extensions.ext_database import db
-from models.model import Account, App, AppMode, AppModelConfig, Conversation, EndUser, Message
-from services.app_model_config_service import AppModelConfigService
-from services.errors.app import MoreLikeThisDisabledError
-from services.errors.app_model_config import AppModelConfigBrokenError
-from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
-from services.errors.message import MessageNotExistsError
+from models.model import Account, App, AppMode, EndUser
 
 
 class CompletionService:
 
     @classmethod
     def completion(cls, app_model: App,
                    user: Union[Account, EndUser],
                    args: Any,
-                   invoke_from: InvokeFrom, streaming: bool = True,
-                   is_model_config_override: bool = False) -> Union[dict, Generator]:
-        # is streaming mode
-        inputs = args['inputs']
-        query = args['query']
-        files = args['files'] if 'files' in args and args['files'] else []
-        auto_generate_name = args['auto_generate_name'] \
-            if 'auto_generate_name' in args else True
-
-        if app_model.mode != AppMode.COMPLETION.value:
-            if not query:
-                raise ValueError('query is required')
-
-        if query:
-            if not isinstance(query, str):
-                raise ValueError('query must be a string')
-
-            query = query.replace('\x00', '')
-
-        conversation_id = args['conversation_id'] if 'conversation_id' in args else None
-
-        conversation = None
-        app_model_config_dict = None
-        if conversation_id:
-            conversation_filter = [
-                Conversation.id == args['conversation_id'],
-                Conversation.app_id == app_model.id,
-                Conversation.status == 'normal'
-            ]
-
-            if isinstance(user, Account):
-                conversation_filter.append(Conversation.from_account_id == user.id)
-            else:
-                conversation_filter.append(Conversation.from_end_user_id == user.id if user else None)
-
-            conversation = db.session.query(Conversation).filter(and_(*conversation_filter)).first()
-
-            if not conversation:
-                raise ConversationNotExistsError()
-
-            if conversation.status != 'normal':
-                raise ConversationCompletedError()
-
-            app_model_config = db.session.query(AppModelConfig).filter(
-                AppModelConfig.id == conversation.app_model_config_id,
-                AppModelConfig.app_id == app_model.id
-            ).first()
-
-            if not app_model_config:
-                raise AppModelConfigBrokenError()
-        else:
-            if app_model.app_model_config_id is None:
-                raise AppModelConfigBrokenError()
-
-            app_model_config = app_model.app_model_config
-
-            if not app_model_config:
-                raise AppModelConfigBrokenError()
-
-            if is_model_config_override:
-                if not isinstance(user, Account):
-                    raise Exception("Only account can override model config")
-
-                # validate config
-                app_model_config_dict = AppModelConfigService.validate_configuration(
-                    tenant_id=app_model.tenant_id,
-                    config=args['model_config'],
-                    app_mode=AppMode.value_of(app_model.mode)
-                )
-
-        # parse files
-        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
-        file_upload_entity = FileUploadConfigManager.convert(app_model_config_dict or app_model_config.to_dict())
-        if file_upload_entity:
-            file_objs = message_file_parser.validate_and_transform_files_arg(
-                files,
-                file_upload_entity,
-                user
+                   invoke_from: InvokeFrom, streaming: bool = True) -> Union[dict, Generator]:
+        """
+        App Completion
+        :param app_model: app model
+        :param user: user
+        :param args: args
+        :param invoke_from: invoke from
+        :param streaming: streaming
+        :return:
+        """
+        if app_model.mode == AppMode.COMPLETION.value:
+            return CompletionAppGenerator().generate(
+                app_model=app_model,
+                user=user,
+                args=args,
+                invoke_from=invoke_from,
+                stream=streaming
+            )
+        elif app_model.mode == AppMode.CHAT.value:
+            return ChatAppGenerator().generate(
+                app_model=app_model,
+                user=user,
+                args=args,
+                invoke_from=invoke_from,
+                stream=streaming
+            )
+        elif app_model.mode == AppMode.AGENT_CHAT.value:
+            return AgentChatAppGenerator().generate(
+                app_model=app_model,
+                user=user,
+                args=args,
+                invoke_from=invoke_from,
+                stream=streaming
             )
         else:
-            file_objs = []
-
-        application_manager = EasyUIBasedAppManager()
-        return application_manager.generate(
-            app_model=app_model,
-            app_model_config=app_model_config,
-            app_model_config_dict=app_model_config_dict,
-            user=user,
-            invoke_from=invoke_from,
-            inputs=inputs,
-            query=query,
-            files=file_objs,
-            conversation=conversation,
-            stream=streaming,
-            extras={
-                "auto_generate_conversation_name": auto_generate_name
-            }
-        )
+            raise ValueError('Invalid app mode')
 
     @classmethod
     def generate_more_like_this(cls, app_model: App,
                                 user: Union[Account, EndUser],
                                 message_id: str,
                                 invoke_from: InvokeFrom,
                                 streaming: bool = True) \
             -> Union[dict, Generator]:
-        if not user:
-            raise ValueError('user cannot be None')
-
-        message = db.session.query(Message).filter(
-            Message.id == message_id,
-            Message.app_id == app_model.id,
-            Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
-            Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
-            Message.from_account_id == (user.id if isinstance(user, Account) else None),
-        ).first()
-
-        if not message:
-            raise MessageNotExistsError()
-
-        current_app_model_config = app_model.app_model_config
-        more_like_this = current_app_model_config.more_like_this_dict
-
-        if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
-            raise MoreLikeThisDisabledError()
-
-        app_model_config = message.app_model_config
-        model_dict = app_model_config.model_dict
-        completion_params = model_dict.get('completion_params')
-        completion_params['temperature'] = 0.9
-        model_dict['completion_params'] = completion_params
-        app_model_config.model = json.dumps(model_dict)
-
-        # parse files
-        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
-        file_upload_entity = FileUploadConfigManager.convert(current_app_model_config.to_dict())
-        if file_upload_entity:
-            file_objs = message_file_parser.transform_message_files(
-                message.files, file_upload_entity
-            )
-        else:
-            file_objs = []
-
-        application_manager = EasyUIBasedAppManager()
-        return application_manager.generate(
+        """
+        Generate more like this
+        :param app_model: app model
+        :param user: user
+        :param message_id: message id
+        :param invoke_from: invoke from
+        :param streaming: streaming
+        :return:
+        """
+        return CompletionAppGenerator().generate_more_like_this(
             app_model=app_model,
-            app_model_config=current_app_model_config,
-            app_model_config_dict=app_model_config.to_dict(),
+            message_id=message_id,
             user=user,
             invoke_from=invoke_from,
-            inputs=message.inputs,
-            query=message.query,
-            files=file_objs,
-            conversation=None,
-            stream=streaming,
-            extras={
-                "auto_generate_conversation_name": False
-            }
+            stream=streaming
         )
-
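The rewritten service above reduces to a pure dispatch on `app_model.mode`. A hedged usage sketch follows; the call signature is the one shown in the diff, while the argument values and the surrounding objects are invented for illustration:

```python
# Sketch only: how a caller might invoke the dispatch-style service above.
from core.app.entities.app_invoke_entities import InvokeFrom
from services.completion_service import CompletionService

response = CompletionService.completion(
    app_model=app_model,          # an App row; its mode picks the generator
    user=end_user,                # Account or EndUser, fetched elsewhere
    args={
        "inputs": {},             # hypothetical request payload
        "query": "Hello",
        "conversation_id": None,
    },
    invoke_from=InvokeFrom.SERVICE_API,
    streaming=True,               # True -> a Generator of streamed chunks
)
```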
diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py
index b3061cc255..9d377cc466 100644
--- a/api/services/workflow/workflow_converter.py
+++ b/api/services/workflow/workflow_converter.py
@@ -8,9 +8,11 @@ from core.app.app_config.entities import (
     FileUploadEntity,
     ModelConfigEntity,
     PromptTemplateEntity,
-    VariableEntity,
+    VariableEntity, EasyUIBasedAppConfig,
 )
-from core.app.app_manager import EasyUIBasedAppManager
+from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
+from core.app.apps.chat.app_config_manager import ChatAppConfigManager
+from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
 from core.helper import encrypter
 from core.model_runtime.entities.llm_entities import LLMMode
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -87,8 +89,7 @@ class WorkflowConverter:
         new_app_mode = self._get_new_app_mode(app_model)
 
         # convert app model config
-        application_manager = EasyUIBasedAppManager()
-        app_config = application_manager.convert_to_app_config(
+        app_config = self._convert_to_app_config(
             app_model=app_model,
             app_model_config=app_model_config
         )
@@ -190,6 +191,30 @@ class WorkflowConverter:
 
         return workflow
 
+    def _convert_to_app_config(self, app_model: App,
+                               app_model_config: AppModelConfig) -> EasyUIBasedAppConfig:
+        app_mode = AppMode.value_of(app_model.mode)
+        if app_mode == AppMode.AGENT_CHAT or app_model.is_agent:
+            app_model.mode = AppMode.AGENT_CHAT.value
+            app_config = AgentChatAppConfigManager.get_app_config(
+                app_model=app_model,
+                app_model_config=app_model_config
+            )
+        elif app_mode == AppMode.CHAT:
+            app_config = ChatAppConfigManager.get_app_config(
+                app_model=app_model,
+                app_model_config=app_model_config
+            )
+        elif app_mode == AppMode.COMPLETION:
+            app_config = CompletionAppConfigManager.get_app_config(
+                app_model=app_model,
+                app_model_config=app_model_config
+            )
+        else:
+            raise ValueError("Invalid app mode")
+
+        return app_config
+
     def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict:
         """
         Convert to Start Node
@@ -566,6 +591,6 @@ class WorkflowConverter:
         :return:
         """
         return db.session.query(APIBasedExtension).filter(
-            APIBasedExtension.tenant_id == tenant_id,
-            APIBasedExtension.id == api_based_extension_id
-        ).first()
+            APIBasedExtension.tenant_id == tenant_id,
+            APIBasedExtension.id == api_based_extension_id
+        ).first()
diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py
index 70f6070c6b..be9fe8d004 100644
--- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py
+++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py
@@ -1,6 +1,6 @@
 from unittest.mock import MagicMock
 
-from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage
 from core.prompt.simple_prompt_transform import SimplePromptTransform
@@ -139,7 +139,7 @@ def test_get_common_chat_app_prompt_template_with_p():
 
 def test__get_chat_model_prompt_messages():
-    model_config_mock = MagicMock(spec=EasyUIBasedModelConfigEntity)
+    model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
     model_config_mock.provider = 'openai'
     model_config_mock.model = 'gpt-4'
 
@@ -191,7 +191,7 @@ def test__get_chat_model_prompt_messages():
 
 def test__get_completion_model_prompt_messages():
-    model_config_mock = MagicMock(spec=EasyUIBasedModelConfigEntity)
+    model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
     model_config_mock.provider = 'openai'
    model_config_mock.model = 'gpt-3.5-turbo-instruct'
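The test changes above only swap the `spec` class on the mocks. As a reminder of what `spec` buys here (standard `unittest.mock` behavior, not something this patch introduces), shown with a stand-in class rather than the real entity:

```python
# Sketch only: standard unittest.mock behavior, with a hypothetical stand-in class.
from unittest.mock import MagicMock

class ModelConfig:  # stand-in for ModelConfigWithCredentialsEntity
    provider = ''
    model = ''

mock = MagicMock(spec=ModelConfig)
mock.provider = 'openai'              # setting attributes is still allowed with spec=
assert isinstance(mock, ModelConfig)  # spec'd mocks pass isinstance checks
# mock.no_such_attr                   # would raise AttributeError: not in the spec
```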

From 602bc67495d62334fc7796a0a6eaeacd19e33770 Mon Sep 17 00:00:00 2001
From: takatost
Date: Sun, 3 Mar 2024 04:18:51 +0800
Subject: [PATCH 097/200] lint fix

---
 api/core/agent/base_agent_runner.py           |  3 ++-
 api/core/app/apps/agent_chat/app_generator.py |  9 +++++----
 api/core/app/apps/agent_chat/app_runner.py    |  3 +--
 api/core/app/apps/base_app_generator.py       |  2 +-
 api/core/app/apps/base_app_runner.py          |  4 +++-
 api/core/app/apps/chat/app_generator.py       |  9 +++++----
 .../app/apps/completion/app_config_manager.py |  2 +-
 api/core/app/apps/completion/app_generator.py | 10 +++++-----
 .../app/apps/message_based_app_generator.py   | 18 ++++++++++++------
 api/core/app/entities/app_invoke_entities.py  |  2 +-
 .../hosting_moderation/hosting_moderation.py  |  2 +-
 api/core/app/generate_task_pipeline.py        | 10 +++++++---
 api/core/rag/retrieval/dataset_retrieval.py   |  2 +-
 api/services/workflow/workflow_converter.py   |  3 ++-
 14 files changed, 47 insertions(+), 32 deletions(-)

diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index ef530b9122..236a5d9cf7 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -10,8 +10,9 @@ from core.app.app_queue_manager import AppQueueManager
 from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
 from core.app.apps.base_app_runner import AppRunner
 from core.app.entities.app_invoke_entities import (
+    AgentChatAppGenerateEntity,
+    InvokeFrom,
     ModelConfigWithCredentialsEntity,
-    InvokeFrom, AgentChatAppGenerateEntity,
 )
 from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py
index 1ab456d822..d5dbdf0dd2 100644
--- a/api/core/app/apps/agent_chat/app_generator.py
+++ b/api/core/app/apps/agent_chat/app_generator.py
@@ -1,18 +1,19 @@
 import logging
 import threading
 import uuid
-from typing import Union, Any, Generator
+from collections.abc import Generator
+from typing import Any, Union
 
-from flask import current_app, Flask
+from flask import Flask, current_app
 from pydantic import ValidationError
 
 from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
-from core.app.app_queue_manager import ConversationTaskStoppedException, PublishFrom, AppQueueManager
+from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
 from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
 from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
-from core.app.entities.app_invoke_entities import InvokeFrom, AgentChatAppGenerateEntity
+from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
 from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from extensions.ext_database import db
diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py
index 6bae5e1648..27a473fb17 100644
--- a/api/core/app/apps/agent_chat/app_runner.py
+++ b/api/core/app/apps/agent_chat/app_runner.py
@@ -7,8 +7,7 @@ from core.agent.fc_agent_runner import FunctionCallAgentRunner
 from core.app.app_queue_manager import AppQueueManager, PublishFrom
 from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
 from core.app.apps.base_app_runner import AppRunner
-from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity, \
-    AgentChatAppGenerateEntity
+from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMUsage
diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py
index 65764021aa..750c6dae10 100644
--- a/api/core/app/apps/base_app_generator.py
+++ b/api/core/app/apps/base_app_generator.py
@@ -1,4 +1,4 @@
-from core.app.app_config.entities import VariableEntity, AppConfig
+from core.app.app_config.entities import AppConfig, VariableEntity
 
 
 class BaseAppGenerator:
diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py
index ee70f161a2..8de71d4bfb 100644
--- a/api/core/app/apps/base_app_runner.py
+++ b/api/core/app/apps/base_app_runner.py
@@ -5,8 +5,10 @@ from typing import Optional, Union, cast
 from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
 from core.app.app_queue_manager import AppQueueManager, PublishFrom
 from core.app.entities.app_invoke_entities import (
+    AppGenerateEntity,
+    EasyUIBasedAppGenerateEntity,
+    InvokeFrom,
     ModelConfigWithCredentialsEntity,
-    InvokeFrom, AppGenerateEntity, EasyUIBasedAppGenerateEntity,
 )
 from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
 from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py
index 712822f3a5..978ac9656b 100644
--- a/api/core/app/apps/chat/app_generator.py
+++ b/api/core/app/apps/chat/app_generator.py
@@ -1,18 +1,19 @@
 import logging
 import threading
 import uuid
-from typing import Union, Any, Generator
+from collections.abc import Generator
+from typing import Any, Union
 
-from flask import current_app, Flask
+from flask import Flask, current_app
 from pydantic import ValidationError
 
 from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
-from core.app.app_queue_manager import ConversationTaskStoppedException, PublishFrom, AppQueueManager
+from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
 from core.app.apps.chat.app_config_manager import ChatAppConfigManager
 from core.app.apps.chat.app_runner import ChatAppRunner
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
-from core.app.entities.app_invoke_entities import InvokeFrom, ChatAppGenerateEntity
+from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
 from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from extensions.ext_database import db
diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py
index 77a1443037..a82e68a337 100644
--- a/api/core/app/apps/completion/app_config_manager.py
+++ b/api/core/app/apps/completion/app_config_manager.py
@@ -10,7 +10,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
 from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
-from models.model import App, AppMode, AppModelConfig, Conversation
+from models.model import App, AppMode, AppModelConfig
 
 
 class CompletionAppConfig(EasyUIBasedAppConfig):
diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py
index d258a3bd9d..9355bae123 100644
--- a/api/core/app/apps/completion/app_generator.py
+++ b/api/core/app/apps/completion/app_generator.py
@@ -1,19 +1,19 @@
-import json
 import logging
 import threading
 import uuid
-from typing import Union, Any, Generator
+from collections.abc import Generator
+from typing import Any, Union
 
-from flask import current_app, Flask
+from flask import Flask, current_app
 from pydantic import ValidationError
 
 from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
-from core.app.app_queue_manager import ConversationTaskStoppedException, PublishFrom, AppQueueManager
+from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
 from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
 from core.app.apps.completion.app_runner import CompletionAppRunner
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
-from core.app.entities.app_invoke_entities import InvokeFrom, CompletionAppGenerateEntity
+from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
 from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from extensions.ext_database import db
diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py
index 783c6c6ee5..2fb609e615 100644
--- a/api/core/app/apps/message_based_app_generator.py
+++ b/api/core/app/apps/message_based_app_generator.py
@@ -1,21 +1,27 @@
 import json
 import logging
-from typing import Union, Generator, Optional
+from collections.abc import Generator
+from typing import Optional, Union
 
 from sqlalchemy import and_
 
 from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
-from core.app.app_queue_manager import ConversationTaskStoppedException, AppQueueManager
+from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException
 from core.app.apps.base_app_generator import BaseAppGenerator
-from core.app.entities.app_invoke_entities import InvokeFrom, ChatAppGenerateEntity, AppGenerateEntity, \
-    CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
+from core.app.entities.app_invoke_entities import (
+    AgentChatAppGenerateEntity,
+    AppGenerateEntity,
+    ChatAppGenerateEntity,
+    CompletionAppGenerateEntity,
+    InvokeFrom,
+)
 from core.app.generate_task_pipeline import GenerateTaskPipeline
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from extensions.ext_database import db
 from models.account import Account
-from models.model import Conversation, Message, AppMode, MessageFile, App, EndUser, AppModelConfig
+from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
 from services.errors.app_model_config import AppModelConfigBrokenError
-from services.errors.conversation import ConversationNotExistsError, ConversationCompletedError
+from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py
index 9097345674..1c4f32b8f2 100644
--- a/api/core/app/entities/app_invoke_entities.py
+++ b/api/core/app/entities/app_invoke_entities.py
@@ -3,7 +3,7 @@ from typing import Any, Optional
 
 from pydantic import BaseModel
 
-from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig, AppConfig
+from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
 from core.entities.provider_configuration import ProviderModelBundle
 from core.file.file_obj import FileObj
 from core.model_runtime.entities.model_entities import AIModelEntity
diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py
index 7d555328db..ec316248a2 100644
--- a/api/core/app/features/hosting_moderation/hosting_moderation.py
+++ b/api/core/app/features/hosting_moderation/hosting_moderation.py
@@ -1,6 +1,6 @@
 import logging
 
-from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, EasyUIBasedAppGenerateEntity
+from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity
 from core.helper import moderation
 from core.model_runtime.entities.message_entities import PromptMessage
diff --git a/api/core/app/generate_task_pipeline.py b/api/core/app/generate_task_pipeline.py
index 926b0e128c..60dfc5cdad 100644
--- a/api/core/app/generate_task_pipeline.py
+++ b/api/core/app/generate_task_pipeline.py
@@ -7,8 +7,12 @@ from typing import Optional, Union, cast
 
 from pydantic import BaseModel
 
 from core.app.app_queue_manager import AppQueueManager, PublishFrom
-from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom, CompletionAppGenerateEntity, \
-    AgentChatAppGenerateEntity
+from core.app.entities.app_invoke_entities import (
+    AgentChatAppGenerateEntity,
+    ChatAppGenerateEntity,
+    CompletionAppGenerateEntity,
+    InvokeFrom,
+)
 from core.app.entities.queue_entities import (
     AnnotationReplyEvent,
     QueueAgentMessageEvent,
@@ -40,7 +44,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.tools.tool_file_manager import ToolFileManager
 from events.message_event import message_was_created
 from extensions.ext_database import db
-from models.model import Conversation, Message, MessageAgentThought, MessageFile, AppMode
+from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile
 from services.annotation_service import AppAnnotationService
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 395f2eb165..ee72842326 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -3,7 +3,7 @@ from typing import Optional, cast
 from langchain.tools import BaseTool
 
 from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
-from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity, InvokeFrom
+from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.agent_entities import PlanningStrategy
 from core.memory.token_buffer_memory import TokenBufferMemory
diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py
index 9d377cc466..527c654381 100644
--- a/api/services/workflow/workflow_converter.py
+++ b/api/services/workflow/workflow_converter.py
@@ -4,11 +4,12 @@ from typing import Optional
 from core.app.app_config.entities import (
     DatasetEntity,
     DatasetRetrieveConfigEntity,
+    EasyUIBasedAppConfig,
     ExternalDataVariableEntity,
     FileUploadEntity,
     ModelConfigEntity,
     PromptTemplateEntity,
-    VariableEntity, EasyUIBasedAppConfig,
+    VariableEntity,
 )
 from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
 from core.app.apps.chat.app_config_manager import ChatAppConfigManager

From be709d4b844f870cb0457f0e58e5e74009405f9b Mon Sep 17 00:00:00 2001
From: takatost
Date: Mon, 4 Mar 2024 02:04:40 +0800
Subject: [PATCH 098/200] add AdvancedChatAppGenerateTaskPipeline

---
 api/core/app/app_queue_manager.py             |  67 +++++-
 .../apps/advanced_chat/app_config_manager.py  |   6 +-
 .../app/apps/advanced_chat/app_generator.py   | 218 ++++++++++++++++++
 api/core/app/apps/advanced_chat/app_runner.py | 103 +++++++++
 api/core/app/apps/base_app_runner.py          |   4 +++-
 .../app/apps/message_based_app_generator.py   |  38 +--
 api/core/app/entities/queue_entities.py       |  74 ++++--
 api/core/workflow/workflow_engine_manager.py  |  38 +++
 .../deduct_quota_when_messaeg_created.py      |   7 +-
 ...rsation_name_when_first_message_created.py |   3 +-
 ...vider_last_used_at_when_messaeg_created.py |   7 +-
 api/models/model.py                           |   6 +-
 api/models/workflow.py                        |  41 ++++
 api/services/workflow_service.py              |  19 +-
 14 files changed, 570 insertions(+), 61 deletions(-)
 create mode 100644 api/core/app/apps/advanced_chat/app_generator.py
 create mode 100644 api/core/app/apps/advanced_chat/app_runner.py

diff --git a/api/core/app/app_queue_manager.py b/api/core/app/app_queue_manager.py
index 4bd491269c..5655c8d979 100644
--- a/api/core/app/app_queue_manager.py
+++ b/api/core/app/app_queue_manager.py
@@ -8,19 +8,24 @@ from sqlalchemy.orm import DeclarativeMeta
 
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.queue_entities import (
-    AnnotationReplyEvent,
     AppQueueEvent,
     QueueAgentMessageEvent,
     QueueAgentThoughtEvent,
+    QueueAnnotationReplyEvent,
     QueueErrorEvent,
+    QueueLLMChunkEvent,
     QueueMessage,
     QueueMessageEndEvent,
-    QueueMessageEvent,
     QueueMessageFileEvent,
     QueueMessageReplaceEvent,
+    QueueNodeFinishedEvent,
+    QueueNodeStartedEvent,
     QueuePingEvent,
     QueueRetrieverResourcesEvent,
     QueueStopEvent,
+    QueueTextChunkEvent,
+    QueueWorkflowFinishedEvent,
+    QueueWorkflowStartedEvent,
 )
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 from extensions.ext_redis import redis_client
@@ -97,18 +102,30 @@ class AppQueueManager:
         """
         self._q.put(None)
 
-    def publish_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None:
+    def publish_llm_chunk(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None:
         """
-        Publish chunk message to channel
+        Publish llm chunk to channel
 
-        :param chunk: chunk
+        :param chunk: llm chunk
        :param pub_from: publish from
         :return:
         """
-        self.publish(QueueMessageEvent(
+        self.publish(QueueLLMChunkEvent(
             chunk=chunk
         ), pub_from)
 
+    def publish_text_chunk(self, text: str, pub_from: PublishFrom) -> None:
+        """
+        Publish text chunk to channel
+
+        :param text: text
+        :param pub_from: publish from
+        :return:
+        """
+        self.publish(QueueTextChunkEvent(
+            text=text
+        ), pub_from)
+
     def publish_agent_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None:
         """
         Publish agent chunk message to channel
@@ -146,7 +163,7 @@ class AppQueueManager:
         :param pub_from: publish from
         :return:
         """
-        self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id), pub_from)
+        self.publish(QueueAnnotationReplyEvent(message_annotation_id=message_annotation_id), pub_from)
 
     def publish_message_end(self, llm_result: LLMResult, pub_from: PublishFrom) -> None:
         """
@@ -158,6 +175,42 @@ class AppQueueManager:
         self.publish(QueueMessageEndEvent(llm_result=llm_result), pub_from)
         self.stop_listen()
 
+    def publish_workflow_started(self, workflow_run_id: str, pub_from: PublishFrom) -> None:
+        """
+        Publish workflow started
+        :param workflow_run_id: workflow run id
+        :param pub_from: publish from
+        :return:
+        """
+        self.publish(QueueWorkflowStartedEvent(workflow_run_id=workflow_run_id), pub_from)
+
+    def publish_workflow_finished(self, workflow_run_id: str, pub_from: PublishFrom) -> None:
+        """
+        Publish workflow finished
+        :param workflow_run_id: workflow run id
+        :param pub_from: publish from
+        :return:
+        """
+        self.publish(QueueWorkflowFinishedEvent(workflow_run_id=workflow_run_id), pub_from)
+
+    def publish_node_started(self, workflow_node_execution_id: str, pub_from: PublishFrom) -> None:
+        """
+        Publish node started
+        :param workflow_node_execution_id: workflow node execution id
+        :param pub_from: publish from
+        :return:
+        """
+        self.publish(QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution_id), pub_from)
+
+    def publish_node_finished(self, workflow_node_execution_id: str, pub_from: PublishFrom) -> None:
+        """
+        Publish node finished
+        :param workflow_node_execution_id: workflow node execution id
+        :param pub_from: publish from
+        :return:
+        """
+        self.publish(QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution_id), pub_from)
+
     def publish_agent_thought(self, message_agent_thought: MessageAgentThought, pub_from: PublishFrom) -> None:
         """
         Publish agent thought
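The new `publish_*` helpers above give workflow execution a first-class event vocabulary. A hedged sketch of the intended call pattern from a runner's perspective; the method names are the ones added in this hunk, while the function and the identifiers are invented for illustration:

```python
# Sketch only: the publish sequence a workflow runner might emit, using the
# helpers added above. The run/node IDs are hypothetical.
from core.app.app_queue_manager import AppQueueManager, PublishFrom

def run_workflow(queue_manager: AppQueueManager) -> None:
    queue_manager.publish_workflow_started("run-id", PublishFrom.APPLICATION_MANAGER)
    queue_manager.publish_node_started("node-exec-id", PublishFrom.APPLICATION_MANAGER)
    # ... the node does its work, streaming partial output as it goes ...
    queue_manager.publish_text_chunk("partial output", PublishFrom.APPLICATION_MANAGER)
    queue_manager.publish_node_finished("node-exec-id", PublishFrom.APPLICATION_MANAGER)
    queue_manager.publish_workflow_finished("run-id", PublishFrom.APPLICATION_MANAGER)
```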
diff --git a/api/core/app/apps/advanced_chat/app_config_manager.py b/api/core/app/apps/advanced_chat/app_config_manager.py
index 72ba4c33d4..3ac26ebe80 100644
--- a/api/core/app/apps/advanced_chat/app_config_manager.py
+++ b/api/core/app/apps/advanced_chat/app_config_manager.py
@@ -1,4 +1,3 @@
-from typing import Optional
 
 from core.app.app_config.base_app_config_manager import BaseAppConfigManager
 from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
@@ -12,7 +11,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
 )
 from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
 from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
-from models.model import App, AppMode, Conversation
+from models.model import App, AppMode
 from models.workflow import Workflow
 
 
@@ -26,8 +25,7 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig):
 class AdvancedChatAppConfigManager(BaseAppConfigManager):
     @classmethod
     def get_app_config(cls, app_model: App,
-                       workflow: Workflow,
-                       conversation: Optional[Conversation] = None) -> AdvancedChatAppConfig:
+                       workflow: Workflow) -> AdvancedChatAppConfig:
         features_dict = workflow.features_dict
 
         app_config = AdvancedChatAppConfig(
--git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py new file mode 100644 index 0000000000..ca2f400547 --- /dev/null +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -0,0 +1,218 @@ +import logging +import threading +import uuid +from collections.abc import Generator +from typing import Any, Union + +from flask import Flask, current_app +from pydantic import ValidationError + +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom +from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager +from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner +from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline +from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.file.message_file_parser import MessageFileParser +from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from core.workflow.workflow_engine_manager import WorkflowEngineManager +from extensions.ext_database import db +from models.account import Account +from models.model import App, Conversation, EndUser, Message + +logger = logging.getLogger(__name__) + + +class AdvancedChatAppGenerator(MessageBasedAppGenerator): + def generate(self, app_model: App, + user: Union[Account, EndUser], + args: Any, + invoke_from: InvokeFrom, + stream: bool = True) \ + -> Union[dict, Generator]: + """ + Generate App response. + + :param app_model: App + :param user: account or end user + :param args: request args + :param invoke_from: invoke from source + :param stream: is stream + """ + if not args.get('query'): + raise ValueError('query is required') + + query = args['query'] + if not isinstance(query, str): + raise ValueError('query must be a string') + + query = query.replace('\x00', '') + inputs = args['inputs'] + + extras = { + "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else True + } + + # get conversation + conversation = None + if args.get('conversation_id'): + conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + + # get workflow + workflow_engine_manager = WorkflowEngineManager() + if invoke_from == InvokeFrom.DEBUGGER: + workflow = workflow_engine_manager.get_draft_workflow(app_model=app_model) + else: + workflow = workflow_engine_manager.get_published_workflow(app_model=app_model) + + if not workflow: + raise ValueError('Workflow not initialized') + + # parse files + files = args['files'] if 'files' in args and args['files'] else [] + message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + file_upload_entity = FileUploadConfigManager.convert(workflow.features_dict) + if file_upload_entity: + file_objs = message_file_parser.validate_and_transform_files_arg( + files, + file_upload_entity, + user + ) + else: + file_objs = [] + + # convert to app config + app_config = AdvancedChatAppConfigManager.get_app_config( + app_model=app_model, + workflow=workflow + ) + + # init application generate entity + application_generate_entity = AdvancedChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + conversation_id=conversation.id if conversation else None, + 
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), + query=query, + files=file_objs, + user_id=user.id, + stream=stream, + invoke_from=invoke_from, + extras=extras + ) + + # init generate records + ( + conversation, + message + ) = self._init_generate_records(application_generate_entity, conversation) + + # init queue manager + queue_manager = AppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id + ) + + # new thread + worker_thread = threading.Thread(target=self._generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'application_generate_entity': application_generate_entity, + 'queue_manager': queue_manager, + 'conversation_id': conversation.id, + 'message_id': message.id, + }) + + worker_thread.start() + + # return response or stream generator + return self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + stream=stream + ) + + def _generate_worker(self, flask_app: Flask, + application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation_id: str, + message_id: str) -> None: + """ + Generate worker in a new thread. + :param flask_app: Flask app + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param conversation_id: conversation ID + :param message_id: message ID + :return: + """ + with flask_app.app_context(): + try: + # get conversation and message + conversation = self._get_conversation(conversation_id) + message = self._get_message(message_id) + + # chatbot app + runner = AdvancedChatAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + except ConversationTaskStoppedException: + pass + except InvokeAuthorizationError: + queue_manager.publish_error( + InvokeAuthorizationError('Incorrect API key provided'), + PublishFrom.APPLICATION_MANAGER + ) + except ValidationError as e: + logger.exception("Validation Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except (ValueError, InvokeError) as e: + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except Exception as e: + logger.exception("Unknown Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + finally: + db.session.remove() + + def _handle_response(self, application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + stream: bool = False) -> Union[dict, Generator]: + """ + Handle response. 
+ :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param conversation: conversation + :param message: message + :param stream: is stream + :return: + """ + # init generate task pipeline + generate_task_pipeline = AdvancedChatAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message + ) + + try: + return generate_task_pipeline.process(stream=stream) + except ValueError as e: + if e.args[0] == "I/O operation on closed file.": # ignore this error + raise ConversationTaskStoppedException() + else: + logger.exception(e) + raise e + finally: + db.session.remove() diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py new file mode 100644 index 0000000000..0d701ae224 --- /dev/null +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -0,0 +1,103 @@ +import logging +from typing import cast + +from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig +from core.app.apps.base_app_runner import AppRunner +from core.app.entities.app_invoke_entities import ( + AdvancedChatAppGenerateEntity, +) +from core.moderation.base import ModerationException +from extensions.ext_database import db +from models.model import App, Conversation, Message + +logger = logging.getLogger(__name__) + + +class AdvancedChatAppRunner(AppRunner): + """ + AdvancedChat Application Runner + """ + + def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message) -> None: + """ + Run application + :param application_generate_entity: application generate entity + :param queue_manager: application queue manager + :param conversation: conversation + :param message: message + :return: + """ + app_config = application_generate_entity.app_config + app_config = cast(AdvancedChatAppConfig, app_config) + + app_record = db.session.query(App).filter(App.id == app_config.app_id).first() + if not app_record: + raise ValueError("App not found") + + inputs = application_generate_entity.inputs + query = application_generate_entity.query + files = application_generate_entity.files + + # moderation + try: + # process sensitive_word_avoidance + _, inputs, query = self.moderation_for_inputs( + app_id=app_record.id, + tenant_id=app_config.tenant_id, + app_generate_entity=application_generate_entity, + inputs=inputs, + query=query, + ) + except ModerationException as e: + # TODO + self.direct_output( + queue_manager=queue_manager, + app_generate_entity=application_generate_entity, + prompt_messages=prompt_messages, + text=str(e), + stream=application_generate_entity.stream + ) + return + + if query: + # annotation reply + annotation_reply = self.query_app_annotations_to_reply( + app_record=app_record, + message=message, + query=query, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from + ) + + if annotation_reply: + queue_manager.publish_annotation_reply( + message_annotation_id=annotation_reply.id, + pub_from=PublishFrom.APPLICATION_MANAGER + ) + + # TODO + self.direct_output( + queue_manager=queue_manager, + app_generate_entity=application_generate_entity, + prompt_messages=prompt_messages, + text=annotation_reply.content, + stream=application_generate_entity.stream + ) + return + + # check hosting moderation + # 
TODO + hosting_moderation_result = self.check_hosting_moderation( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + prompt_messages=prompt_messages + ) + + if hosting_moderation_result: + return + + # todo RUN WORKFLOW \ No newline at end of file diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 8de71d4bfb..4e099c9ae1 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -187,7 +187,7 @@ class AppRunner: if stream: index = 0 for token in text: - queue_manager.publish_chunk_message(LLMResultChunk( + queue_manager.publish_llm_chunk(LLMResultChunk( model=app_generate_entity.model_config.model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( @@ -261,7 +261,7 @@ class AppRunner: usage = None for result in invoke_result: if not agent: - queue_manager.publish_chunk_message(result, PublishFrom.APPLICATION_MANAGER) + queue_manager.publish_llm_chunk(result, PublishFrom.APPLICATION_MANAGER) else: queue_manager.publish_agent_chunk_message(result, PublishFrom.APPLICATION_MANAGER) diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 2fb609e615..dab72bd6d6 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -8,14 +8,15 @@ from sqlalchemy import and_ from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException from core.app.apps.base_app_generator import BaseAppGenerator +from core.app.apps.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.app.entities.app_invoke_entities import ( + AdvancedChatAppGenerateEntity, AgentChatAppGenerateEntity, AppGenerateEntity, ChatAppGenerateEntity, CompletionAppGenerateEntity, InvokeFrom, ) -from core.app.generate_task_pipeline import GenerateTaskPipeline from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db from models.account import Account @@ -31,7 +32,8 @@ class MessageBasedAppGenerator(BaseAppGenerator): def _handle_response(self, application_generate_entity: Union[ ChatAppGenerateEntity, CompletionAppGenerateEntity, - AgentChatAppGenerateEntity + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity ], queue_manager: AppQueueManager, conversation: Conversation, @@ -47,7 +49,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): :return: """ # init generate task pipeline - generate_task_pipeline = GenerateTaskPipeline( + generate_task_pipeline = EasyUIBasedGenerateTaskPipeline( application_generate_entity=application_generate_entity, queue_manager=queue_manager, conversation=conversation, @@ -114,7 +116,8 @@ class MessageBasedAppGenerator(BaseAppGenerator): application_generate_entity: Union[ ChatAppGenerateEntity, CompletionAppGenerateEntity, - AgentChatAppGenerateEntity + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity ], conversation: Optional[Conversation] = None) \ -> tuple[Conversation, Message]: @@ -135,10 +138,19 @@ class MessageBasedAppGenerator(BaseAppGenerator): from_source = 'console' account_id = application_generate_entity.user_id - override_model_configs = None - if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS \ - and app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]: - override_model_configs = 
app_config.app_model_config_dict + if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity): + app_model_config_id = None + override_model_configs = None + model_provider = None + model_id = None + else: + app_model_config_id = app_config.app_model_config_id + model_provider = application_generate_entity.model_config.provider + model_id = application_generate_entity.model_config.model + override_model_configs = None + if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS \ + and app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]: + override_model_configs = app_config.app_model_config_dict # get conversation introduction introduction = self._get_conversation_introduction(application_generate_entity) @@ -146,9 +158,9 @@ class MessageBasedAppGenerator(BaseAppGenerator): if not conversation: conversation = Conversation( app_id=app_config.app_id, - app_model_config_id=app_config.app_model_config_id, - model_provider=application_generate_entity.model_config.provider, - model_id=application_generate_entity.model_config.model, + app_model_config_id=app_model_config_id, + model_provider=model_provider, + model_id=model_id, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, mode=app_config.app_mode.value, name='New conversation', @@ -167,8 +179,8 @@ class MessageBasedAppGenerator(BaseAppGenerator): message = Message( app_id=app_config.app_id, - model_provider=application_generate_entity.model_config.provider, - model_id=application_generate_entity.model_config.model, + model_provider=model_provider, + model_id=model_id, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, conversation_id=conversation.id, inputs=application_generate_entity.inputs, diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index c1f8fb7e89..25bdd7d9e3 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -10,14 +10,19 @@ class QueueEvent(Enum): """ QueueEvent enum """ - MESSAGE = "message" + LLM_CHUNK = "llm_chunk" + TEXT_CHUNK = "text_chunk" AGENT_MESSAGE = "agent_message" - MESSAGE_REPLACE = "message-replace" - MESSAGE_END = "message-end" - RETRIEVER_RESOURCES = "retriever-resources" - ANNOTATION_REPLY = "annotation-reply" - AGENT_THOUGHT = "agent-thought" - MESSAGE_FILE = "message-file" + MESSAGE_REPLACE = "message_replace" + MESSAGE_END = "message_end" + WORKFLOW_STARTED = "workflow_started" + WORKFLOW_FINISHED = "workflow_finished" + NODE_STARTED = "node_started" + NODE_FINISHED = "node_finished" + RETRIEVER_RESOURCES = "retriever_resources" + ANNOTATION_REPLY = "annotation_reply" + AGENT_THOUGHT = "agent_thought" + MESSAGE_FILE = "message_file" ERROR = "error" PING = "ping" STOP = "stop" @@ -30,13 +35,22 @@ class AppQueueEvent(BaseModel): event: QueueEvent -class QueueMessageEvent(AppQueueEvent): +class QueueLLMChunkEvent(AppQueueEvent): """ - QueueMessageEvent entity + QueueLLMChunkEvent entity """ - event = QueueEvent.MESSAGE + event = QueueEvent.LLM_CHUNK chunk: LLMResultChunk + +class QueueTextChunkEvent(AppQueueEvent): + """ + QueueTextChunkEvent entity + """ + event = QueueEvent.TEXT_CHUNK + chunk_text: str + + class QueueAgentMessageEvent(AppQueueEvent): """ QueueMessageEvent entity @@ -61,9 +75,9 @@ class QueueRetrieverResourcesEvent(AppQueueEvent): retriever_resources: list[dict] -class AnnotationReplyEvent(AppQueueEvent): +class QueueAnnotationReplyEvent(AppQueueEvent): 
""" - AnnotationReplyEvent entity + QueueAnnotationReplyEvent entity """ event = QueueEvent.ANNOTATION_REPLY message_annotation_id: str @@ -76,6 +90,38 @@ class QueueMessageEndEvent(AppQueueEvent): event = QueueEvent.MESSAGE_END llm_result: LLMResult + +class QueueWorkflowStartedEvent(AppQueueEvent): + """ + QueueWorkflowStartedEvent entity + """ + event = QueueEvent.WORKFLOW_STARTED + workflow_run_id: str + + +class QueueWorkflowFinishedEvent(AppQueueEvent): + """ + QueueWorkflowFinishedEvent entity + """ + event = QueueEvent.WORKFLOW_FINISHED + workflow_run_id: str + + +class QueueNodeStartedEvent(AppQueueEvent): + """ + QueueNodeStartedEvent entity + """ + event = QueueEvent.NODE_STARTED + workflow_node_execution_id: str + + +class QueueNodeFinishedEvent(AppQueueEvent): + """ + QueueNodeFinishedEvent entity + """ + event = QueueEvent.NODE_FINISHED + workflow_node_execution_id: str + class QueueAgentThoughtEvent(AppQueueEvent): """ @@ -84,13 +130,15 @@ class QueueAgentThoughtEvent(AppQueueEvent): event = QueueEvent.AGENT_THOUGHT agent_thought_id: str + class QueueMessageFileEvent(AppQueueEvent): """ QueueAgentThoughtEvent entity """ event = QueueEvent.MESSAGE_FILE message_file_id: str - + + class QueueErrorEvent(AppQueueEvent): """ QueueErrorEvent entity diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index e69de29bb2..f7955a87e8 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -0,0 +1,38 @@ +from typing import Optional + +from extensions.ext_database import db +from models.model import App +from models.workflow import Workflow + + +class WorkflowEngineManager: + def get_draft_workflow(self, app_model: App) -> Optional[Workflow]: + """ + Get draft workflow + """ + # fetch draft workflow by app_model + workflow = db.session.query(Workflow).filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.version == 'draft' + ).first() + + # return draft workflow + return workflow + + def get_published_workflow(self, app_model: App) -> Optional[Workflow]: + """ + Get published workflow + """ + if not app_model.workflow_id: + return None + + # fetch published workflow by workflow_id + workflow = db.session.query(Workflow).filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.id == app_model.workflow_id + ).first() + + # return published workflow + return workflow diff --git a/api/events/event_handlers/deduct_quota_when_messaeg_created.py b/api/events/event_handlers/deduct_quota_when_messaeg_created.py index 77d1ab0822..53cbb2ecdc 100644 --- a/api/events/event_handlers/deduct_quota_when_messaeg_created.py +++ b/api/events/event_handlers/deduct_quota_when_messaeg_created.py @@ -1,4 +1,4 @@ -from core.app.entities.app_invoke_entities import ChatAppGenerateEntity +from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity from core.entities.provider_entities import QuotaUnit from events.message_event import message_was_created from extensions.ext_database import db @@ -8,7 +8,10 @@ from models.provider import Provider, ProviderType @message_was_created.connect def handle(sender, **kwargs): message = sender - application_generate_entity: ChatAppGenerateEntity = kwargs.get('application_generate_entity') + application_generate_entity = kwargs.get('application_generate_entity') + + if not isinstance(application_generate_entity, ChatAppGenerateEntity | 
AgentChatAppGenerateEntity): + return model_config = application_generate_entity.model_config provider_model_bundle = model_config.provider_model_bundle diff --git a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py index f5f3ba2540..31535bf4ef 100644 --- a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py +++ b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py @@ -1,6 +1,7 @@ from core.llm_generator.llm_generator import LLMGenerator from events.message_event import message_was_created from extensions.ext_database import db +from models.model import AppMode @message_was_created.connect @@ -15,7 +16,7 @@ def handle(sender, **kwargs): auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True) if auto_generate_conversation_name and is_first_message: - if conversation.mode == 'chat': + if conversation.mode != AppMode.COMPLETION.value: app_model = conversation.app if not app_model: return diff --git a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py index eca773f3b3..ae983cc5d1 100644 --- a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py +++ b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py @@ -1,6 +1,6 @@ from datetime import datetime -from core.app.entities.app_invoke_entities import ChatAppGenerateEntity +from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity from events.message_event import message_was_created from extensions.ext_database import db from models.provider import Provider @@ -9,7 +9,10 @@ from models.provider import Provider @message_was_created.connect def handle(sender, **kwargs): message = sender - application_generate_entity: ChatAppGenerateEntity = kwargs.get('application_generate_entity') + application_generate_entity = kwargs.get('application_generate_entity') + + if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): + return db.session.query(Provider).filter( Provider.tenant_id == application_generate_entity.app_config.tenant_id, diff --git a/api/models/model.py b/api/models/model.py index f8f9a0a3cd..c579c3dee8 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -451,10 +451,10 @@ class Conversation(db.Model): id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) app_id = db.Column(UUID, nullable=False) - app_model_config_id = db.Column(UUID, nullable=False) - model_provider = db.Column(db.String(255), nullable=False) + app_model_config_id = db.Column(UUID, nullable=True) + model_provider = db.Column(db.String(255), nullable=True) override_model_configs = db.Column(db.Text) - model_id = db.Column(db.String(255), nullable=False) + model_id = db.Column(db.String(255), nullable=True) mode = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False) summary = db.Column(db.Text) diff --git a/api/models/workflow.py b/api/models/workflow.py index f9c906b85c..2540d33402 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -272,6 +272,10 @@ class WorkflowRun(db.Model): return EndUser.query.get(self.created_by) \ if created_by_role == CreatedByRole.END_USER else None + @property + def outputs_dict(self): + return self.outputs if not 
self.outputs else json.loads(self.outputs) + class WorkflowNodeExecutionTriggeredFrom(Enum): """ @@ -294,6 +298,28 @@ class WorkflowNodeExecutionTriggeredFrom(Enum): raise ValueError(f'invalid workflow node execution triggered from value {value}') +class WorkflowNodeExecutionStatus(Enum): + """ + Workflow Node Execution Status Enum + """ + RUNNING = 'running' + SUCCEEDED = 'succeeded' + FAILED = 'failed' + + @classmethod + def value_of(cls, value: str) -> 'WorkflowNodeExecutionStatus': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid workflow node execution status value {value}') + + class WorkflowNodeExecution(db.Model): """ Workflow Node Execution @@ -387,6 +413,21 @@ class WorkflowNodeExecution(db.Model): return EndUser.query.get(self.created_by) \ if created_by_role == CreatedByRole.END_USER else None + @property + def inputs_dict(self): + return self.inputs if not self.inputs else json.loads(self.inputs) + + @property + def outputs_dict(self): + return self.outputs if not self.outputs else json.loads(self.outputs) + + @property + def process_data_dict(self): + return self.process_data if not self.process_data else json.loads(self.process_data) + + @property + def execution_metadata_dict(self): + return self.execution_metadata if not self.execution_metadata else json.loads(self.execution_metadata) class WorkflowAppLog(db.Model): """ diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index c9efd056ff..13ea67d343 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -4,6 +4,7 @@ from typing import Optional from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.account import Account from models.model import App, AppMode @@ -21,15 +22,10 @@ class WorkflowService: """ Get draft workflow """ - # fetch draft workflow by app_model - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.version == 'draft' - ).first() + workflow_engine_manager = WorkflowEngineManager() # return draft workflow - return workflow + return workflow_engine_manager.get_draft_workflow(app_model=app_model) def get_published_workflow(self, app_model: App) -> Optional[Workflow]: """ @@ -38,15 +34,10 @@ class WorkflowService: if not app_model.workflow_id: return None - # fetch published workflow by workflow_id - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.id == app_model.workflow_id - ).first() + workflow_engine_manager = WorkflowEngineManager() # return published workflow - return workflow + return workflow_engine_manager.get_published_workflow(app_model=app_model) def sync_draft_workflow(self, app_model: App, graph: dict, From e9004a06a563b92a45df16dbadd99a3855378cfc Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 02:04:46 +0800 Subject: [PATCH 099/200] lint fix --- .../advanced_chat/generate_task_pipeline.py | 563 ++++++++++++++++++ .../easy_ui_based_generate_task_pipeline.py} | 43 +- 2 files changed, 585 insertions(+), 21 deletions(-) create mode 100644 api/core/app/apps/advanced_chat/generate_task_pipeline.py rename 
api/core/app/{generate_task_pipeline.py => apps/easy_ui_based_generate_task_pipeline.py} (95%) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py new file mode 100644 index 0000000000..d443435fc1 --- /dev/null +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -0,0 +1,563 @@ +import json +import logging +import time +from collections.abc import Generator +from typing import Optional, Union + +from pydantic import BaseModel + +from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.app.entities.app_invoke_entities import ( + AdvancedChatAppGenerateEntity, + InvokeFrom, +) +from core.app.entities.queue_entities import ( + QueueAnnotationReplyEvent, + QueueErrorEvent, + QueueMessageFileEvent, + QueueMessageReplaceEvent, + QueueNodeFinishedEvent, + QueueNodeStartedEvent, + QueuePingEvent, + QueueRetrieverResourcesEvent, + QueueStopEvent, + QueueTextChunkEvent, + QueueWorkflowFinishedEvent, + QueueWorkflowStartedEvent, +) +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from core.moderation.output_moderation import ModerationRule, OutputModeration +from core.tools.tool_file_manager import ToolFileManager +from events.message_event import message_was_created +from extensions.ext_database import db +from models.model import Conversation, Message, MessageFile +from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowRun, WorkflowRunStatus +from services.annotation_service import AppAnnotationService + +logger = logging.getLogger(__name__) + + +class TaskState(BaseModel): + """ + TaskState entity + """ + answer: str = "" + metadata: dict = {} + + +class AdvancedChatAppGenerateTaskPipeline: + """ + AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. + """ + + def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message) -> None: + """ + Initialize GenerateTaskPipeline. + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param conversation: conversation + :param message: message + """ + self._application_generate_entity = application_generate_entity + self._queue_manager = queue_manager + self._conversation = conversation + self._message = message + self._task_state = TaskState( + # TaskState defines only answer and metadata; usage is tracked in metadata['usage'] + ) + self._start_at = time.perf_counter() + self._output_moderation_handler = self._init_output_moderation() + + def process(self, stream: bool) -> Union[dict, Generator]: + """ + Process generate task pipeline. + :return: + """ + if stream: + return self._process_stream_response() + else: + return self._process_blocking_response() + + def _process_blocking_response(self) -> dict: + """ + Process blocking response.
+ :return: + """ + for queue_message in self._queue_manager.listen(): + event = queue_message.event + + if isinstance(event, QueueErrorEvent): + raise self._handle_error(event) + elif isinstance(event, QueueRetrieverResourcesEvent): + self._task_state.metadata['retriever_resources'] = event.retriever_resources + elif isinstance(event, QueueAnnotationReplyEvent): + annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) + if annotation: + account = annotation.account + self._task_state.metadata['annotation_reply'] = { + 'id': annotation.id, + 'account': { + 'id': annotation.account_id, + 'name': account.name if account else 'Dify user' + } + } + + self._task_state.answer = annotation.content + elif isinstance(event, QueueNodeFinishedEvent): + workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) + if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value: + if workflow_node_execution.node_type == 'llm': # todo use enum + outputs = workflow_node_execution.outputs_dict + usage_dict = outputs.get('usage', {}) + self._task_state.metadata['usage'] = usage_dict + elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent): + if isinstance(event, QueueWorkflowFinishedEvent): + workflow_run = self._get_workflow_run(event.workflow_run_id) + if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: + outputs = workflow_run.outputs_dict + self._task_state.answer = outputs.get('text', '') + else: + raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))) + + # response moderation + if self._output_moderation_handler: + self._output_moderation_handler.stop_thread() + + self._task_state.answer = self._output_moderation_handler.moderation_completion( + completion=self._task_state.answer, + public_event=False + ) + + # Save message + self._save_message() + + response = { + 'event': 'message', + 'task_id': self._application_generate_entity.task_id, + 'id': self._message.id, + 'message_id': self._message.id, + 'conversation_id': self._conversation.id, + 'mode': self._conversation.mode, + 'answer': self._task_state.answer, + 'metadata': {}, + 'created_at': int(self._message.created_at.timestamp()) + } + + if self._task_state.metadata: + response['metadata'] = self._get_response_metadata() + + return response + else: + continue + + def _process_stream_response(self) -> Generator: + """ + Process stream response.
+ :return: + """ + for message in self._queue_manager.listen(): + event = message.event + + if isinstance(event, QueueErrorEvent): + data = self._error_to_stream_response_data(self._handle_error(event)) + yield self._yield_response(data) + break + elif isinstance(event, QueueWorkflowStartedEvent): + workflow_run = self._get_workflow_run(event.workflow_run_id) + response = { + 'event': 'workflow_started', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': event.workflow_run_id, + 'data': { + 'id': workflow_run.id, + 'workflow_id': workflow_run.workflow_id, + 'created_at': int(workflow_run.created_at.timestamp()) + } + } + + yield self._yield_response(response) + elif isinstance(event, QueueNodeStartedEvent): + workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) + response = { + 'event': 'node_started', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': workflow_node_execution.workflow_run_id, + 'data': { + 'id': workflow_node_execution.id, + 'node_id': workflow_node_execution.node_id, + 'index': workflow_node_execution.index, + 'predecessor_node_id': workflow_node_execution.predecessor_node_id, + 'inputs': workflow_node_execution.inputs_dict, + 'created_at': int(workflow_node_execution.created_at.timestamp()) + } + } + + yield self._yield_response(response) + elif isinstance(event, QueueNodeFinishedEvent): + workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) + if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value: + if workflow_node_execution.node_type == 'llm': # todo use enum + outputs = workflow_node_execution.outputs_dict + usage_dict = outputs.get('usage', {}) + self._task_state.metadata['usage'] = usage_dict + + response = { + 'event': 'node_finished', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': workflow_node_execution.workflow_run_id, + 'data': { + 'id': workflow_node_execution.id, + 'node_id': workflow_node_execution.node_id, + 'index': workflow_node_execution.index, + 'predecessor_node_id': workflow_node_execution.predecessor_node_id, + 'inputs': workflow_node_execution.inputs_dict, + 'process_data': workflow_node_execution.process_data_dict, + 'outputs': workflow_node_execution.outputs_dict, + 'status': workflow_node_execution.status, + 'error': workflow_node_execution.error, + 'elapsed_time': workflow_node_execution.elapsed_time, + 'execution_metadata': workflow_node_execution.execution_metadata_dict, + 'created_at': int(workflow_node_execution.created_at.timestamp()), + 'finished_at': int(workflow_node_execution.finished_at.timestamp()) + } + } + + yield self._yield_response(response) + elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent): + if isinstance(event, QueueWorkflowFinishedEvent): + workflow_run = self._get_workflow_run(event.workflow_run_id) + if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: + outputs = workflow_run.outputs_dict + self._task_state.answer = outputs.get('text', '') + else: + err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) + data = self._error_to_stream_response_data(self._handle_error(err_event)) + yield self._yield_response(data) + break + + workflow_run_response = { + 'event': 'workflow_finished', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': event.workflow_run_id, + 'data': { + 'id': workflow_run.id, + 'workflow_id': workflow_run.workflow_id, + 'status':
workflow_run.status, + 'outputs': workflow_run.outputs_dict, + 'error': workflow_run.error, + 'elapsed_time': workflow_run.elapsed_time, + 'total_tokens': workflow_run.total_tokens, + 'total_price': workflow_run.total_price, + 'currency': workflow_run.currency, + 'total_steps': workflow_run.total_steps, + 'created_at': int(workflow_run.created_at.timestamp()), + 'finished_at': int(workflow_run.finished_at.timestamp()) + } + } + + yield self._yield_response(workflow_run_response) + + # response moderation + if self._output_moderation_handler: + self._output_moderation_handler.stop_thread() + + self._task_state.answer = self._output_moderation_handler.moderation_completion( + completion=self._task_state.answer, + public_event=False + ) + + self._output_moderation_handler = None + + replace_response = { + 'event': 'message_replace', + 'task_id': self._application_generate_entity.task_id, + 'message_id': self._message.id, + 'conversation_id': self._conversation.id, + 'answer': self._task_state.answer, + 'created_at': int(self._message.created_at.timestamp()) + } + + yield self._yield_response(replace_response) + + # Save message + self._save_message() + + response = { + 'event': 'message_end', + 'task_id': self._application_generate_entity.task_id, + 'id': self._message.id, + 'message_id': self._message.id, + 'conversation_id': self._conversation.id, + } + + if self._task_state.metadata: + response['metadata'] = self._get_response_metadata() + + yield self._yield_response(response) + elif isinstance(event, QueueRetrieverResourcesEvent): + self._task_state.metadata['retriever_resources'] = event.retriever_resources + elif isinstance(event, QueueAnnotationReplyEvent): + annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) + if annotation: + account = annotation.account + self._task_state.metadata['annotation_reply'] = { + 'id': annotation.id, + 'account': { + 'id': annotation.account_id, + 'name': account.name if account else 'Dify user' + } + } + + self._task_state.answer = annotation.content + elif isinstance(event, QueueMessageFileEvent): + message_file: MessageFile = ( + db.session.query(MessageFile) + .filter(MessageFile.id == event.message_file_id) + .first() + ) + # get extension + if '.' 
in message_file.url: + extension = f'.{message_file.url.split(".")[-1]}' + if len(extension) > 10: + extension = '.bin' + else: + extension = '.bin' + # add sign url + url = ToolFileManager.sign_file(file_id=message_file.id, extension=extension) + + if message_file: + response = { + 'event': 'message_file', + 'conversation_id': self._conversation.id, + 'id': message_file.id, + 'type': message_file.type, + 'belongs_to': message_file.belongs_to or 'user', + 'url': url + } + + yield self._yield_response(response) + elif isinstance(event, QueueTextChunkEvent): + delta_text = event.chunk_text + if delta_text is None: + continue + + if self._output_moderation_handler: + if self._output_moderation_handler.should_direct_output(): + # stop subscribe new token when output moderation should direct output + self._task_state.answer = self._output_moderation_handler.get_final_output() + self._queue_manager.publish_text_chunk(self._task_state.answer, PublishFrom.TASK_PIPELINE) + self._queue_manager.publish( + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), + PublishFrom.TASK_PIPELINE + ) + continue + else: + self._output_moderation_handler.append_new_token(delta_text) + + self._task_state.answer += delta_text + response = self._handle_chunk(delta_text) + yield self._yield_response(response) + elif isinstance(event, QueueMessageReplaceEvent): + response = { + 'event': 'message_replace', + 'task_id': self._application_generate_entity.task_id, + 'message_id': self._message.id, + 'conversation_id': self._conversation.id, + 'answer': event.text, + 'created_at': int(self._message.created_at.timestamp()) + } + + yield self._yield_response(response) + elif isinstance(event, QueuePingEvent): + yield "event: ping\n\n" + else: + continue + + def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun: + """ + Get workflow run. + :param workflow_run_id: workflow run id + :return: + """ + return db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() + + def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution: + """ + Get workflow node execution. + :param workflow_node_execution_id: workflow node execution id + :return: + """ + return db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).first() + + def _save_message(self) -> None: + """ + Save message. 
+ :return: + """ + self._message = db.session.query(Message).filter(Message.id == self._message.id).first() + + self._message.answer = self._task_state.answer + self._message.provider_response_latency = time.perf_counter() - self._start_at + + if self._task_state.metadata and self._task_state.metadata.get('usage'): + usage = LLMUsage(**self._task_state.metadata['usage']) + + self._message.message_tokens = usage.prompt_tokens + self._message.message_unit_price = usage.prompt_unit_price + self._message.message_price_unit = usage.prompt_price_unit + self._message.answer_tokens = usage.completion_tokens + self._message.answer_unit_price = usage.completion_unit_price + self._message.answer_price_unit = usage.completion_price_unit + self._message.provider_response_latency = time.perf_counter() - self._start_at + self._message.total_price = usage.total_price + self._message.currency = usage.currency + + db.session.commit() + + message_was_created.send( + self._message, + application_generate_entity=self._application_generate_entity, + conversation=self._conversation, + is_first_message=self._application_generate_entity.conversation_id is None, + extras=self._application_generate_entity.extras + ) + + def _handle_chunk(self, text: str) -> dict: + """ + Handle completed event. + :param text: text + :return: + """ + response = { + 'event': 'message', + 'id': self._message.id, + 'task_id': self._application_generate_entity.task_id, + 'message_id': self._message.id, + 'conversation_id': self._conversation.id, + 'answer': text, + 'created_at': int(self._message.created_at.timestamp()) + } + + return response + + def _handle_error(self, event: QueueErrorEvent) -> Exception: + """ + Handle error event. + :param event: event + :return: + """ + logger.debug("error: %s", event.error) + e = event.error + + if isinstance(e, InvokeAuthorizationError): + return InvokeAuthorizationError('Incorrect API key provided') + elif isinstance(e, InvokeError) or isinstance(e, ValueError): + return e + else: + return Exception(e.description if getattr(e, 'description', None) is not None else str(e)) + + def _error_to_stream_response_data(self, e: Exception) -> dict: + """ + Error to stream response. + :param e: exception + :return: + """ + error_responses = { + ValueError: {'code': 'invalid_param', 'status': 400}, + ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400}, + QuotaExceededError: { + 'code': 'provider_quota_exceeded', + 'message': "Your quota for Dify Hosted Model Provider has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials.", + 'status': 400 + }, + ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400}, + InvokeError: {'code': 'completion_request_error', 'status': 400} + } + + # Determine the response based on the type of exception + data = None + for k, v in error_responses.items(): + if isinstance(e, k): + data = v + + if data: + data.setdefault('message', getattr(e, 'description', str(e))) + else: + logging.error(e) + data = { + 'code': 'internal_server_error', + 'message': 'Internal Server Error, please contact support.', + 'status': 500 + } + + return { + 'event': 'error', + 'task_id': self._application_generate_entity.task_id, + 'message_id': self._message.id, + **data + } + + def _get_response_metadata(self) -> dict: + """ + Get response metadata by invoke from. 
+ :return: + """ + metadata = {} + + # show_retrieve_source + if 'retriever_resources' in self._task_state.metadata: + if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: + metadata['retriever_resources'] = self._task_state.metadata['retriever_resources'] + else: + metadata['retriever_resources'] = [] + for resource in self._task_state.metadata['retriever_resources']: + metadata['retriever_resources'].append({ + 'segment_id': resource['segment_id'], + 'position': resource['position'], + 'document_name': resource['document_name'], + 'score': resource['score'], + 'content': resource['content'], + }) + # show annotation reply + if 'annotation_reply' in self._task_state.metadata: + if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: + metadata['annotation_reply'] = self._task_state.metadata['annotation_reply'] + + # show usage + if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: + metadata['usage'] = self._task_state.metadata['usage'] + + return metadata + + def _yield_response(self, response: dict) -> str: + """ + Yield response. + :param response: response + :return: + """ + return "data: " + json.dumps(response) + "\n\n" + + def _init_output_moderation(self) -> Optional[OutputModeration]: + """ + Init output moderation. + :return: + """ + app_config = self._application_generate_entity.app_config + sensitive_word_avoidance = app_config.sensitive_word_avoidance + + if sensitive_word_avoidance: + return OutputModeration( + tenant_id=app_config.tenant_id, + app_id=app_config.app_id, + rule=ModerationRule( + type=sensitive_word_avoidance.type, + config=sensitive_word_avoidance.config + ), + on_message_replace_func=self._queue_manager.publish_message_replace + ) diff --git a/api/core/app/generate_task_pipeline.py b/api/core/app/apps/easy_ui_based_generate_task_pipeline.py similarity index 95% rename from api/core/app/generate_task_pipeline.py rename to api/core/app/apps/easy_ui_based_generate_task_pipeline.py index 60dfc5cdad..80596668b8 100644 --- a/api/core/app/generate_task_pipeline.py +++ b/api/core/app/apps/easy_ui_based_generate_task_pipeline.py @@ -14,12 +14,12 @@ from core.app.entities.app_invoke_entities import ( InvokeFrom, ) from core.app.entities.queue_entities import ( - AnnotationReplyEvent, QueueAgentMessageEvent, QueueAgentThoughtEvent, + QueueAnnotationReplyEvent, QueueErrorEvent, + QueueLLMChunkEvent, QueueMessageEndEvent, - QueueMessageEvent, QueueMessageFileEvent, QueueMessageReplaceEvent, QueuePingEvent, @@ -40,6 +40,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeErr from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder from core.moderation.output_moderation import ModerationRule, OutputModeration +from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.tools.tool_file_manager import ToolFileManager from events.message_event import message_was_created @@ -58,9 +59,9 @@ class TaskState(BaseModel): metadata: dict = {} -class GenerateTaskPipeline: +class EasyUIBasedGenerateTaskPipeline: """ - GenerateTaskPipeline is a class that generate stream output and state management for Application. + EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application. 
""" def __init__(self, application_generate_entity: Union[ @@ -79,12 +80,13 @@ class GenerateTaskPipeline: :param message: message """ self._application_generate_entity = application_generate_entity + self._model_config = application_generate_entity.model_config self._queue_manager = queue_manager self._conversation = conversation self._message = message self._task_state = TaskState( llm_result=LLMResult( - model=self._application_generate_entity.model_config.model, + model=self._model_config.model, prompt_messages=[], message=AssistantPromptMessage(content=""), usage=LLMUsage.empty_usage() @@ -119,7 +121,7 @@ class GenerateTaskPipeline: raise self._handle_error(event) elif isinstance(event, QueueRetrieverResourcesEvent): self._task_state.metadata['retriever_resources'] = event.retriever_resources - elif isinstance(event, AnnotationReplyEvent): + elif isinstance(event, QueueAnnotationReplyEvent): annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) if annotation: account = annotation.account @@ -136,7 +138,7 @@ class GenerateTaskPipeline: if isinstance(event, QueueMessageEndEvent): self._task_state.llm_result = event.llm_result else: - model_config = self._application_generate_entity.model_config + model_config = self._model_config model = model_config.model model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) @@ -193,7 +195,7 @@ class GenerateTaskPipeline: 'created_at': int(self._message.created_at.timestamp()) } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: response['conversation_id'] = self._conversation.id if self._task_state.metadata: @@ -219,7 +221,7 @@ class GenerateTaskPipeline: if isinstance(event, QueueMessageEndEvent): self._task_state.llm_result = event.llm_result else: - model_config = self._application_generate_entity.model_config + model_config = self._model_config model = model_config.model model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) @@ -272,7 +274,7 @@ class GenerateTaskPipeline: 'created_at': int(self._message.created_at.timestamp()) } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: replace_response['conversation_id'] = self._conversation.id yield self._yield_response(replace_response) @@ -287,7 +289,7 @@ class GenerateTaskPipeline: 'message_id': self._message.id, } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: response['conversation_id'] = self._conversation.id if self._task_state.metadata: @@ -296,7 +298,7 @@ class GenerateTaskPipeline: yield self._yield_response(response) elif isinstance(event, QueueRetrieverResourcesEvent): self._task_state.metadata['retriever_resources'] = event.retriever_resources - elif isinstance(event, AnnotationReplyEvent): + elif isinstance(event, QueueAnnotationReplyEvent): annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) if annotation: account = annotation.account @@ -334,7 +336,7 @@ class GenerateTaskPipeline: 'message_files': agent_thought.files } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: response['conversation_id'] = self._conversation.id yield self._yield_response(response) @@ -365,12 +367,12 @@ class GenerateTaskPipeline: 'url': url } - if self._conversation.mode == 'chat': + 
if self._conversation.mode != AppMode.COMPLETION.value: response['conversation_id'] = self._conversation.id yield self._yield_response(response) - elif isinstance(event, QueueMessageEvent | QueueAgentMessageEvent): + elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent): chunk = event.chunk delta_text = chunk.delta.message.content if delta_text is None: @@ -383,7 +385,7 @@ class GenerateTaskPipeline: if self._output_moderation_handler.should_direct_output(): # stop subscribe new token when output moderation should direct output self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output() - self._queue_manager.publish_chunk_message(LLMResultChunk( + self._queue_manager.publish_llm_chunk(LLMResultChunk( model=self._task_state.llm_result.model, prompt_messages=self._task_state.llm_result.prompt_messages, delta=LLMResultChunkDelta( @@ -411,7 +413,7 @@ class GenerateTaskPipeline: 'created_at': int(self._message.created_at.timestamp()) } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: response['conversation_id'] = self._conversation.id yield self._yield_response(response) @@ -452,8 +454,7 @@ class GenerateTaskPipeline: conversation=self._conversation, is_first_message=self._application_generate_entity.app_config.app_mode in [ AppMode.AGENT_CHAT, - AppMode.CHAT, - AppMode.ADVANCED_CHAT + AppMode.CHAT ] and self._application_generate_entity.conversation_id is None, extras=self._application_generate_entity.extras ) @@ -473,7 +474,7 @@ class GenerateTaskPipeline: 'created_at': int(self._message.created_at.timestamp()) } - if self._conversation.mode == 'chat': + if self._conversation.mode != AppMode.COMPLETION.value: response['conversation_id'] = self._conversation.id return response @@ -583,7 +584,7 @@ class GenerateTaskPipeline: :return: """ prompts = [] - if self._application_generate_entity.model_config.mode == 'chat': + if self._model_config.mode == ModelMode.CHAT.value: for prompt_message in prompt_messages: if prompt_message.role == PromptMessageRole.USER: role = 'user' From d9b8a938c6a68ea4cdbdbcb9c01333e356eafe08 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 02:05:47 +0800 Subject: [PATCH 100/200] use enum instead --- api/core/app/apps/advanced_chat/generate_task_pipeline.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index d443435fc1..2aa649afea 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -30,6 +30,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.moderation.output_moderation import ModerationRule, OutputModeration from core.tools.tool_file_manager import ToolFileManager +from core.workflow.entities.NodeEntities import NodeType from events.message_event import message_was_created from extensions.ext_database import db from models.model import Conversation, Message, MessageFile @@ -111,7 +112,7 @@ class AdvancedChatAppGenerateTaskPipeline: elif isinstance(event, QueueNodeFinishedEvent): workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value: - if workflow_node_execution.node_type == 'llm': # todo use enum + if 
workflow_node_execution.node_type == NodeType.LLM.value: outputs = workflow_node_execution.outputs_dict usage_dict = outputs.get('usage', {}) self._task_state.metadata['usage'] = usage_dict @@ -201,7 +202,7 @@ class AdvancedChatAppGenerateTaskPipeline: elif isinstance(event, QueueNodeFinishedEvent): workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value: - if workflow_node_execution.node_type == 'llm': # todo use enum + if workflow_node_execution.node_type == NodeType.LLM.value: outputs = workflow_node_execution.outputs_dict usage_dict = outputs.get('usage', {}) self._task_state.metadata['usage'] = usage_dict From 75559bcbf90168ab4cf5f0b04881b0e4b01d6835 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 02:06:27 +0800 Subject: [PATCH 101/200] replace block type to node type --- api/core/workflow/entities/NodeEntities.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/api/core/workflow/entities/NodeEntities.py b/api/core/workflow/entities/NodeEntities.py index d72b000dfb..80471cc702 100644 --- a/api/core/workflow/entities/NodeEntities.py +++ b/api/core/workflow/entities/NodeEntities.py @@ -19,14 +19,14 @@ class NodeType(Enum): VARIABLE_ASSIGNER = 'variable-assigner' @classmethod - def value_of(cls, value: str) -> 'BlockType': + def value_of(cls, value: str) -> 'NodeType': """ - Get value of given block type. + Get value of given node type. - :param value: block type value - :return: block type + :param value: node type value + :return: node type """ - for block_type in cls: - if block_type.value == value: - return block_type - raise ValueError(f'invalid block type value {value}') + for node_type in cls: + if node_type.value == value: + return node_type + raise ValueError(f'invalid node type value {value}') From df809ff435c155510121c2e083a477b9fc13e28e Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 13:21:24 +0800 Subject: [PATCH 102/200] add get default node config --- api/controllers/console/app/app.py | 2 +- api/controllers/console/app/workflow.py | 35 ++++++++- .../advanced_chat/generate_task_pipeline.py | 2 +- .../{NodeEntities.py => node_entities.py} | 0 api/core/workflow/nodes/base_node.py | 12 ++++ api/core/workflow/nodes/code/__init__.py | 0 api/core/workflow/nodes/code/code_node.py | 64 +++++++++++++++++ .../workflow/nodes/direct_answer/__init__.py | 0 .../nodes/direct_answer/direct_answer_node.py | 5 ++ api/core/workflow/nodes/end/end_node.py | 5 ++ .../workflow/nodes/http_request/__init__.py | 0 .../nodes/http_request/http_request_node.py | 5 ++ api/core/workflow/nodes/if_else/__init__.py | 0 .../workflow/nodes/if_else/if_else_node.py | 5 ++ .../nodes/knowledge_retrieval/__init__.py | 0 .../knowledge_retrieval_node.py | 5 ++ api/core/workflow/nodes/llm/__init__.py | 0 api/core/workflow/nodes/llm/llm_node.py | 40 +++++++++++ .../nodes/question_classifier/__init__.py | 0 .../question_classifier_node.py | 19 +++++ api/core/workflow/nodes/start/__init__.py | 0 api/core/workflow/nodes/start/start_node.py | 5 ++ .../nodes/template_transform/__init__.py | 0 .../template_transform_node.py | 25 +++++++ api/core/workflow/nodes/tool/__init__.py | 0 api/core/workflow/nodes/tool/tool_node.py | 5 ++ .../nodes/variable_assigner/__init__.py | 0 .../variable_assigner_node.py | 5 ++ api/core/workflow/workflow_engine_manager.py | 60 ++++++++++++++++ api/services/app_service.py | 2 +- api/services/workflow/defaults.py | 
72 ------------------- api/services/workflow/workflow_converter.py | 2 +- api/services/workflow_service.py | 19 ++++- 33 files changed, 314 insertions(+), 80 deletions(-) rename api/core/workflow/entities/{NodeEntities.py => node_entities.py} (100%) create mode 100644 api/core/workflow/nodes/base_node.py create mode 100644 api/core/workflow/nodes/code/__init__.py create mode 100644 api/core/workflow/nodes/code/code_node.py create mode 100644 api/core/workflow/nodes/direct_answer/__init__.py create mode 100644 api/core/workflow/nodes/direct_answer/direct_answer_node.py create mode 100644 api/core/workflow/nodes/http_request/__init__.py create mode 100644 api/core/workflow/nodes/http_request/http_request_node.py create mode 100644 api/core/workflow/nodes/if_else/__init__.py create mode 100644 api/core/workflow/nodes/if_else/if_else_node.py create mode 100644 api/core/workflow/nodes/knowledge_retrieval/__init__.py create mode 100644 api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py create mode 100644 api/core/workflow/nodes/llm/__init__.py create mode 100644 api/core/workflow/nodes/llm/llm_node.py create mode 100644 api/core/workflow/nodes/question_classifier/__init__.py create mode 100644 api/core/workflow/nodes/question_classifier/question_classifier_node.py create mode 100644 api/core/workflow/nodes/start/__init__.py create mode 100644 api/core/workflow/nodes/start/start_node.py create mode 100644 api/core/workflow/nodes/template_transform/__init__.py create mode 100644 api/core/workflow/nodes/template_transform/template_transform_node.py create mode 100644 api/core/workflow/nodes/tool/__init__.py create mode 100644 api/core/workflow/nodes/tool/tool_node.py create mode 100644 api/core/workflow/nodes/variable_assigner/__init__.py create mode 100644 api/core/workflow/nodes/variable_assigner/variable_assigner_node.py delete mode 100644 api/services/workflow/defaults.py diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 7b2411b96f..66bcbccefe 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -34,7 +34,7 @@ class AppListApi(Resource): parser = reqparse.RequestParser() parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args') parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent', 'channel', 'all'], default='all', location='args', required=False) + parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent-chat', 'channel', 'all'], default='all', location='args', required=False) parser.add_argument('name', type=str, location='args', required=False) args = parser.parse_args() diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 54585d8519..5dfb2b1443 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -1,3 +1,5 @@ +import json + from flask_restful import Resource, marshal_with, reqparse from controllers.console import api @@ -147,7 +149,7 @@ class PublishedWorkflowApi(Resource): } -class DefaultBlockConfigApi(Resource): +class DefaultBlockConfigsApi(Resource): @setup_required @login_required @account_initialization_required @@ -161,6 +163,34 @@ class DefaultBlockConfigApi(Resource): return workflow_service.get_default_block_configs() +class DefaultBlockConfigApi(Resource): + @setup_required + @login_required + 
@account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def get(self, app_model: App, block_type: str): + """ + Get default block config + """ + parser = reqparse.RequestParser() + parser.add_argument('q', type=str, location='args') + args = parser.parse_args() + + filters = None + if args.get('q'): + try: + filters = json.loads(args.get('q')) + except json.JSONDecodeError: + raise ValueError('Invalid filters') + + # Get default block configs + workflow_service = WorkflowService() + return workflow_service.get_default_block_config( + node_type=block_type, + filters=filters + ) + + class ConvertToWorkflowApi(Resource): @setup_required @login_required @@ -188,5 +218,6 @@ api.add_resource(DraftWorkflowRunApi, '/apps//workflows/draft/run') api.add_resource(WorkflowTaskStopApi, '/apps//workflows/tasks//stop') api.add_resource(DraftWorkflowNodeRunApi, '/apps//workflows/draft/nodes//run') api.add_resource(PublishedWorkflowApi, '/apps//workflows/published') -api.add_resource(DefaultBlockConfigApi, '/apps//workflows/default-workflow-block-configs') +api.add_resource(DefaultBlockConfigsApi, '/apps//workflows/default-workflow-block-configs') +api.add_resource(DefaultBlockConfigApi, '/apps//workflows/default-workflow-block-configs/:block_type') api.add_resource(ConvertToWorkflowApi, '/apps//convert-to-workflow') diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 2aa649afea..77e779a0ad 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -30,7 +30,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.moderation.output_moderation import ModerationRule, OutputModeration from core.tools.tool_file_manager import ToolFileManager -from core.workflow.entities.NodeEntities import NodeType +from core.workflow.entities.node_entities import NodeType from events.message_event import message_was_created from extensions.ext_database import db from models.model import Conversation, Message, MessageFile diff --git a/api/core/workflow/entities/NodeEntities.py b/api/core/workflow/entities/node_entities.py similarity index 100% rename from api/core/workflow/entities/NodeEntities.py rename to api/core/workflow/entities/node_entities.py diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py new file mode 100644 index 0000000000..665338af08 --- /dev/null +++ b/api/core/workflow/nodes/base_node.py @@ -0,0 +1,12 @@ +from typing import Optional + + +class BaseNode: + @classmethod + def get_default_config(cls, filters: Optional[dict] = None) -> dict: + """ + Get default config of node. + :param filters: filter by node config parameters. + :return: + """ + return {} diff --git a/api/core/workflow/nodes/code/__init__.py b/api/core/workflow/nodes/code/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py new file mode 100644 index 0000000000..7e69f91d11 --- /dev/null +++ b/api/core/workflow/nodes/code/code_node.py @@ -0,0 +1,64 @@ +from typing import Optional + +from core.workflow.nodes.base_node import BaseNode + + +class CodeNode(BaseNode): + @classmethod + def get_default_config(cls, filters: Optional[dict] = None) -> dict: + """ + Get default config of node. 
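With both resources registered, the block-config endpoints can be exercised like any other console API. One caveat: Flask routing takes angle-bracket converters such as `<string:block_type>`, not the colon syntax in the `add_resource` call above, which a later patch in this series adjusts. A hypothetical client call; the base URL, app id, and token below are placeholders, not values from the diff:

```python
import json
import requests

BASE = 'http://localhost:5001/console/api'     # hypothetical local deployment
HEADERS = {'Authorization': 'Bearer <token>'}  # placeholder auth
APP_ID = 'your-app-uuid'                       # placeholder app id

# list the default configs of every node type that defines one
resp = requests.get(
    f'{BASE}/apps/{APP_ID}/workflows/default-workflow-block-configs',
    headers=HEADERS,
)

# fetch one node type, narrowing via the JSON-encoded `q` filter
resp = requests.get(
    f'{BASE}/apps/{APP_ID}/workflows/default-workflow-block-configs/code',
    params={'q': json.dumps({'code_language': 'javascript'})},
    headers=HEADERS,
)
print(resp.json())
```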
+ :param filters: filter by node config parameters. + :return: + """ + if filters and filters.get("code_language") == "javascript": + return { + "type": "code", + "config": { + "variables": [ + { + "variable": "arg1", + "value_selector": [] + }, + { + "variable": "arg2", + "value_selector": [] + } + ], + "code_language": "javascript", + "code": "async function main(arg1, arg2) {\n return new Promise((resolve, reject) => {" + "\n if (true) {\n resolve({\n \"result\": arg1 + arg2" + "\n });\n } else {\n reject(\"e\");\n }\n });\n}", + "outputs": [ + { + "variable": "result", + "variable_type": "number" + } + ] + } + } + + return { + "type": "code", + "config": { + "variables": [ + { + "variable": "arg1", + "value_selector": [] + }, + { + "variable": "arg2", + "value_selector": [] + } + ], + "code_language": "python3", + "code": "def main(\n arg1: int,\n arg2: int,\n) -> int:\n return {\n \"result\": arg1 " + "+ arg2\n }", + "outputs": [ + { + "variable": "result", + "variable_type": "number" + } + ] + } + } diff --git a/api/core/workflow/nodes/direct_answer/__init__.py b/api/core/workflow/nodes/direct_answer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/direct_answer/direct_answer_node.py b/api/core/workflow/nodes/direct_answer/direct_answer_node.py new file mode 100644 index 0000000000..c6013974b8 --- /dev/null +++ b/api/core/workflow/nodes/direct_answer/direct_answer_node.py @@ -0,0 +1,5 @@ +from core.workflow.nodes.base_node import BaseNode + + +class DirectAnswerNode(BaseNode): + pass diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index e69de29bb2..f9aea89af7 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -0,0 +1,5 @@ +from core.workflow.nodes.base_node import BaseNode + + +class EndNode(BaseNode): + pass diff --git a/api/core/workflow/nodes/http_request/__init__.py b/api/core/workflow/nodes/http_request/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py new file mode 100644 index 0000000000..5be25a9834 --- /dev/null +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -0,0 +1,5 @@ +from core.workflow.nodes.base_node import BaseNode + + +class HttpRequestNode(BaseNode): + pass diff --git a/api/core/workflow/nodes/if_else/__init__.py b/api/core/workflow/nodes/if_else/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py new file mode 100644 index 0000000000..98a5c85db2 --- /dev/null +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -0,0 +1,5 @@ +from core.workflow.nodes.base_node import BaseNode + + +class IfElseNode(BaseNode): + pass diff --git a/api/core/workflow/nodes/knowledge_retrieval/__init__.py b/api/core/workflow/nodes/knowledge_retrieval/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py new file mode 100644 index 0000000000..c6dd624921 --- /dev/null +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -0,0 +1,5 @@ +from core.workflow.nodes.base_node import BaseNode + + +class KnowledgeRetrievalNode(BaseNode): + pass diff --git 
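`filters` is the only branch point in `CodeNode.get_default_config`: `{'code_language': 'javascript'}` selects the JavaScript template, and anything else falls through to the Python 3 one. A quick check, assuming the `api` package introduced above is importable:

```python
from core.workflow.nodes.code.code_node import CodeNode

js = CodeNode.get_default_config(filters={'code_language': 'javascript'})
assert js['config']['code_language'] == 'javascript'

# no filters (or any other language value) yields the python3 template
py = CodeNode.get_default_config()
assert py['config']['code_language'] == 'python3'
```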
a/api/core/workflow/nodes/llm/__init__.py b/api/core/workflow/nodes/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py new file mode 100644 index 0000000000..1c7277e942 --- /dev/null +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -0,0 +1,40 @@ +from typing import Optional + +from core.workflow.nodes.base_node import BaseNode + + +class LLMNode(BaseNode): + @classmethod + def get_default_config(cls, filters: Optional[dict] = None) -> dict: + """ + Get default config of node. + :param filters: filter by node config parameters. + :return: + """ + return { + "type": "llm", + "config": { + "prompt_templates": { + "chat_model": { + "prompts": [ + { + "role": "system", + "text": "You are a helpful AI assistant." + } + ] + }, + "completion_model": { + "conversation_histories_role": { + "user_prefix": "Human", + "assistant_prefix": "Assistant" + }, + "prompt": { + "text": "Here is the chat histories between human and assistant, inside " + " XML tags.\n\n\n{{" + "#histories#}}\n\n\n\nHuman: {{#query#}}\n\nAssistant:" + }, + "stop": ["Human:"] + } + } + } + } diff --git a/api/core/workflow/nodes/question_classifier/__init__.py b/api/core/workflow/nodes/question_classifier/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py new file mode 100644 index 0000000000..f676b6372a --- /dev/null +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -0,0 +1,19 @@ +from typing import Optional + +from core.workflow.nodes.base_node import BaseNode + + +class QuestionClassifierNode(BaseNode): + @classmethod + def get_default_config(cls, filters: Optional[dict] = None) -> dict: + """ + Get default config of node. + :param filters: filter by node config parameters. + :return: + """ + return { + "type": "question-classifier", + "config": { + "instructions": "" # TODO + } + } diff --git a/api/core/workflow/nodes/start/__init__.py b/api/core/workflow/nodes/start/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py new file mode 100644 index 0000000000..8cce655728 --- /dev/null +++ b/api/core/workflow/nodes/start/start_node.py @@ -0,0 +1,5 @@ +from core.workflow.nodes.base_node import BaseNode + + +class StartNode(BaseNode): + pass diff --git a/api/core/workflow/nodes/template_transform/__init__.py b/api/core/workflow/nodes/template_transform/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py new file mode 100644 index 0000000000..2bf26e307e --- /dev/null +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -0,0 +1,25 @@ +from typing import Optional + +from core.workflow.nodes.base_node import BaseNode + + +class TemplateTransformNode(BaseNode): + @classmethod + def get_default_config(cls, filters: Optional[dict] = None) -> dict: + """ + Get default config of node. + :param filters: filter by node config parameters. 
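`LLMNode` ships one default prompt template per model mode: a role-tagged message list for chat models, and a single text prompt with `Human`/`Assistant` history prefixes plus a stop word for completion models. A small selection sketch; the `pick_template` helper is illustrative, not part of the diff:

```python
from core.workflow.nodes.llm.llm_node import LLMNode

templates = LLMNode.get_default_config()['config']['prompt_templates']

def pick_template(model_mode: str) -> dict:
    # chat models take a list of role-tagged prompts; completion models
    # take one prompt body plus conversation-history role prefixes
    return templates['chat_model' if model_mode == 'chat' else 'completion_model']

print(pick_template('chat')['prompts'][0]['role'])  # system
print(pick_template('completion')['stop'])          # ['Human:']
```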
+ :return: + """ + return { + "type": "template-transform", + "config": { + "variables": [ + { + "variable": "arg1", + "value_selector": [] + } + ], + "template": "{{ arg1 }}" + } + } diff --git a/api/core/workflow/nodes/tool/__init__.py b/api/core/workflow/nodes/tool/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py new file mode 100644 index 0000000000..b805a53d2f --- /dev/null +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -0,0 +1,5 @@ +from core.workflow.nodes.base_node import BaseNode + + +class ToolNode(BaseNode): + pass diff --git a/api/core/workflow/nodes/variable_assigner/__init__.py b/api/core/workflow/nodes/variable_assigner/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/variable_assigner/variable_assigner_node.py b/api/core/workflow/nodes/variable_assigner/variable_assigner_node.py new file mode 100644 index 0000000000..231a26a661 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/variable_assigner_node.py @@ -0,0 +1,5 @@ +from core.workflow.nodes.base_node import BaseNode + + +class VariableAssignerNode(BaseNode): + pass diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index f7955a87e8..73e92d5e89 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,9 +1,37 @@ from typing import Optional +from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.code.code_node import CodeNode +from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode +from core.workflow.nodes.end.end_node import EndNode +from core.workflow.nodes.http_request.http_request_node import HttpRequestNode +from core.workflow.nodes.if_else.if_else_node import IfElseNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode +from core.workflow.nodes.llm.llm_node import LLMNode +from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from core.workflow.nodes.tool.tool_node import ToolNode +from core.workflow.nodes.variable_assigner.variable_assigner_node import VariableAssignerNode from extensions.ext_database import db from models.model import App from models.workflow import Workflow +node_classes = { + NodeType.START: StartNode, + NodeType.END: EndNode, + NodeType.DIRECT_ANSWER: DirectAnswerNode, + NodeType.LLM: LLMNode, + NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, + NodeType.IF_ELSE: IfElseNode, + NodeType.CODE: CodeNode, + NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode, + NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode, + NodeType.HTTP_REQUEST: HttpRequestNode, + NodeType.TOOL: ToolNode, + NodeType.VARIABLE_ASSIGNER: VariableAssignerNode, +} + class WorkflowEngineManager: def get_draft_workflow(self, app_model: App) -> Optional[Workflow]: @@ -36,3 +64,35 @@ class WorkflowEngineManager: # return published workflow return workflow + + def get_default_configs(self) -> list[dict]: + """ + Get default block configs + """ + default_block_configs = [] + for node_type, node_class in node_classes.items(): + default_config = node_class.get_default_config() + if default_config: + default_block_configs.append({ + 'type': 
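The `node_classes` dict gives the engine a single dispatch point from graph JSON to implementation. Reduced to a standalone sketch with stub classes (the real mapping covers all twelve node types):

```python
from enum import Enum

class NodeType(Enum):
    START = 'start'
    LLM = 'llm'

    @classmethod
    def value_of(cls, value: str) -> 'NodeType':
        for node_type in cls:
            if node_type.value == value:
                return node_type
        raise ValueError(f'invalid node type value {value}')

class StartNode: ...
class LLMNode: ...

node_classes = {NodeType.START: StartNode, NodeType.LLM: LLMNode}

def resolve(node_type_value: str) -> type:
    # raw string from the stored graph -> enum member -> node class
    return node_classes[NodeType.value_of(node_type_value)]

assert resolve('llm') is LLMNode
```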
node_type.value, + 'config': default_config + }) + + return default_block_configs + + def get_default_config(self, node_type: NodeType, filters: Optional[dict] = None) -> Optional[dict]: + """ + Get default config of node. + :param node_type: node type + :param filters: filter by node config parameters. + :return: + """ + node_class = node_classes.get(node_type) + if not node_class: + return None + + default_config = node_class.get_default_config(filters=filters) + if not default_config: + return None + + return default_config diff --git a/api/services/app_service.py b/api/services/app_service.py index f1d0e3df19..6011b6a667 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -35,7 +35,7 @@ class AppService: filters.append(App.mode.in_([AppMode.WORKFLOW.value, AppMode.COMPLETION.value])) elif args['mode'] == 'chat': filters.append(App.mode.in_([AppMode.CHAT.value, AppMode.ADVANCED_CHAT.value])) - elif args['mode'] == 'agent': + elif args['mode'] == 'agent-chat': filters.append(App.mode == AppMode.AGENT_CHAT.value) elif args['mode'] == 'channel': filters.append(App.mode == AppMode.CHANNEL.value) diff --git a/api/services/workflow/defaults.py b/api/services/workflow/defaults.py deleted file mode 100644 index 67804fa4eb..0000000000 --- a/api/services/workflow/defaults.py +++ /dev/null @@ -1,72 +0,0 @@ -# default block config -default_block_configs = [ - { - "type": "llm", - "config": { - "prompt_templates": { - "chat_model": { - "prompts": [ - { - "role": "system", - "text": "You are a helpful AI assistant." - } - ] - }, - "completion_model": { - "conversation_histories_role": { - "user_prefix": "Human", - "assistant_prefix": "Assistant" - }, - "prompt": { - "text": "Here is the chat histories between human and assistant, inside " - " XML tags.\n\n\n{{" - "#histories#}}\n\n\n\nHuman: {{#query#}}\n\nAssistant:" - }, - "stop": ["Human:"] - } - } - } - }, - { - "type": "code", - "config": { - "variables": [ - { - "variable": "arg1", - "value_selector": [] - }, - { - "variable": "arg2", - "value_selector": [] - } - ], - "code_language": "python3", - "code": "def main(\n arg1: int,\n arg2: int,\n) -> int:\n return {\n \"result\": arg1 " - "+ arg2\n }", - "outputs": [ - { - "variable": "result", - "variable_type": "number" - } - ] - } - }, - { - "type": "template-transform", - "config": { - "variables": [ - { - "variable": "arg1", - "value_selector": [] - } - ], - "template": "{{ arg1 }}" - } - }, - { - "type": "question-classifier", - "config": { - "instructions": "" # TODO - } - } -] diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 527c654381..4c7e4db47a 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -18,7 +18,7 @@ from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.simple_prompt_transform import SimplePromptTransform -from core.workflow.entities.NodeEntities import NodeType +from core.workflow.entities.node_entities import NodeType from core.workflow.nodes.end.entities import EndNodeOutputType from events.app_event import app_was_created from extensions.ext_database import db diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 13ea67d343..396845d16a 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -4,6 +4,7 @@ from typing import Optional from 
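With this patch the static `defaults.py` list disappears: each node class owns its default config, `WorkflowEngineManager` aggregates them, and `WorkflowService` is a thin pass-through that also converts the incoming type string to the enum. A sketch of the two manager entry points, assuming a configured Dify API environment (the manager module pulls in db models on import):

```python
from core.workflow.entities.node_entities import NodeType
from core.workflow.workflow_engine_manager import WorkflowEngineManager

manager = WorkflowEngineManager()

# every node type whose class returns a non-empty default config;
# stub nodes return {} and are skipped
for block in manager.get_default_configs():
    print(block['type'])

# one node type, narrowed by filters
config = manager.get_default_config(NodeType.CODE,
                                    filters={'code_language': 'javascript'})
```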
core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.workflow.entities.node_entities import NodeType from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.account import Account @@ -121,12 +122,26 @@ class WorkflowService: # return new workflow return workflow - def get_default_block_configs(self) -> dict: + def get_default_block_configs(self) -> list[dict]: """ Get default block configs """ # return default block config - return default_block_configs + workflow_engine_manager = WorkflowEngineManager() + return workflow_engine_manager.get_default_configs() + + def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]: + """ + Get default config of node. + :param node_type: node type + :param filters: filter by node config parameters. + :return: + """ + node_type = NodeType.value_of(node_type) + + # return default block config + workflow_engine_manager = WorkflowEngineManager() + return workflow_engine_manager.get_default_config(node_type, filters) def convert_to_workflow(self, app_model: App, account: Account) -> App: """ From de40422205ea941e562d29a501ce8782c999cffa Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 13:21:30 +0800 Subject: [PATCH 103/200] lint fix --- api/services/workflow_service.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 396845d16a..0be0783ae0 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -10,7 +10,6 @@ from extensions.ext_database import db from models.account import Account from models.model import App, AppMode from models.workflow import Workflow, WorkflowType -from services.workflow.defaults import default_block_configs from services.workflow.workflow_converter import WorkflowConverter From 242fcf0145683481d6a8ebce1258fe796472744c Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 13:32:59 +0800 Subject: [PATCH 104/200] fix typo --- api/core/agent/cot_agent_runner.py | 2 +- api/core/agent/fc_agent_runner.py | 2 +- api/core/app/apps/base_app_runner.py | 2 +- api/core/app/apps/chat/app_runner.py | 2 +- api/core/app/apps/completion/app_runner.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 8b444ef3be..ad1e6e610d 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -134,7 +134,7 @@ class CotAgentRunner(BaseAgentRunner): input=query ) - # recalc llm max tokens + # recale llm max tokens self.recalc_llm_max_tokens(self.model_config, prompt_messages) # invoke model chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm( diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 30e5cdd694..3c7e55e293 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -107,7 +107,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): messages_ids=message_file_ids ) - # recalc llm max tokens + # recale llm max tokens self.recalc_llm_max_tokens(self.model_config, prompt_messages) # invoke model chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm( diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 
4e099c9ae1..dda240d778 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -84,7 +84,7 @@ class AppRunner: return rest_tokens - def recale_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity, + def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]): # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit model_type_instance = model_config.provider_model_bundle.model_type_instance diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 57aca9d3e6..bce4606f21 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -189,7 +189,7 @@ class ChatAppRunner(AppRunner): return # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit - self.recale_llm_max_tokens( + self.recalc_llm_max_tokens( model_config=application_generate_entity.model_config, prompt_messages=prompt_messages ) diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index c5b8ca6c9a..d67d485e1d 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -149,7 +149,7 @@ class CompletionAppRunner(AppRunner): return # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit - self.recale_llm_max_tokens( + self.recalc_llm_max_tokens( model_config=application_generate_entity.model_config, prompt_messages=prompt_messages ) From 3086893ee76e56fdd2155b4270139805d0388c77 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 14:15:17 +0800 Subject: [PATCH 105/200] fix typo --- api/core/agent/cot_agent_runner.py | 2 +- api/core/agent/fc_agent_runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index ad1e6e610d..8b444ef3be 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -134,7 +134,7 @@ class CotAgentRunner(BaseAgentRunner): input=query ) - # recale llm max tokens + # recalc llm max tokens self.recalc_llm_max_tokens(self.model_config, prompt_messages) # invoke model chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm( diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 3c7e55e293..30e5cdd694 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -107,7 +107,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): messages_ids=message_file_ids ) - # recale llm max tokens + # recalc llm max tokens self.recalc_llm_max_tokens(self.model_config, prompt_messages) # invoke model chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm( From df753e84a3b8239cf58f04689610ceee6ff4bccd Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 17:23:27 +0800 Subject: [PATCH 106/200] fix workflow api return --- api/controllers/console/app/workflow.py | 91 +++++++-- .../app/apps/advanced_chat/app_generator.py | 16 +- api/core/app/apps/advanced_chat/app_runner.py | 178 +++++++++++++----- api/core/app/entities/queue_entities.py | 1 + api/core/workflow/entities/node_entities.py | 9 + api/core/workflow/entities/variable_pool.py | 82 ++++++++ api/core/workflow/nodes/base_node.py | 37 ++++ api/core/workflow/workflow_engine_manager.py | 34 +++- api/fields/workflow_fields.py | 4 +- api/fields/workflow_run_fields.py | 20 +- 
api/models/workflow.py | 8 + api/services/workflow_service.py | 39 +++- 12 files changed, 434 insertions(+), 85 deletions(-) create mode 100644 api/core/workflow/entities/variable_pool.py diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 5dfb2b1443..9ee6ca9dbd 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -1,18 +1,28 @@ import json +import logging +from typing import Generator +from flask import Response, stream_with_context from flask_restful import Resource, marshal_with, reqparse +from werkzeug.exceptions import NotFound, InternalServerError +import services from controllers.console import api -from controllers.console.app.error import DraftWorkflowNotExist +from controllers.console.app.error import DraftWorkflowNotExist, ConversationCompletedError from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from core.app.entities.app_invoke_entities import InvokeFrom from fields.workflow_fields import workflow_fields +from libs.helper import uuid_value from libs.login import current_user, login_required from models.model import App, AppMode from services.workflow_service import WorkflowService +logger = logging.getLogger(__name__) + + class DraftWorkflowApi(Resource): @setup_required @login_required @@ -59,23 +69,80 @@ class DraftWorkflowApi(Resource): } -class DraftWorkflowRunApi(Resource): +class AdvancedChatDraftWorkflowRunApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @get_app_model(mode=[AppMode.ADVANCED_CHAT]) def post(self, app_model: App): """ Run draft workflow """ - # TODO - workflow_service = WorkflowService() - workflow_service.run_draft_workflow(app_model=app_model, account=current_user) + parser = reqparse.RequestParser() + parser.add_argument('inputs', type=dict, required=True, location='json') + parser.add_argument('query', type=str, location='json', default='') + parser.add_argument('files', type=list, required=False, location='json') + parser.add_argument('conversation_id', type=uuid_value, location='json') + args = parser.parse_args() - # TODO - return { - "result": "success" - } + workflow_service = WorkflowService() + try: + response = workflow_service.run_advanced_chat_draft_workflow( + app_model=app_model, + user=current_user, + args=args, + invoke_from=InvokeFrom.DEBUGGER + ) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.conversation.ConversationCompletedError: + raise ConversationCompletedError() + except ValueError as e: + raise e + except Exception as e: + logging.exception("internal server error.") + raise InternalServerError() + + def generate() -> Generator: + yield from response + + return Response(stream_with_context(generate()), status=200, + mimetype='text/event-stream') + + +class DraftWorkflowRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.WORKFLOW]) + def post(self, app_model: App): + """ + Run draft workflow + """ + parser = reqparse.RequestParser() + parser.add_argument('inputs', type=dict, required=True, location='json') + args = parser.parse_args() + + workflow_service = WorkflowService() + + try: + response = workflow_service.run_draft_workflow( + 
app_model=app_model, + user=current_user, + args=args, + invoke_from=InvokeFrom.DEBUGGER + ) + except ValueError as e: + raise e + except Exception as e: + logging.exception("internal server error.") + raise InternalServerError() + + def generate() -> Generator: + yield from response + + return Response(stream_with_context(generate()), status=200, + mimetype='text/event-stream') class WorkflowTaskStopApi(Resource): @@ -214,10 +281,12 @@ class ConvertToWorkflowApi(Resource): api.add_resource(DraftWorkflowApi, '/apps//workflows/draft') +api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps//advanced-chat/workflows/draft/run') api.add_resource(DraftWorkflowRunApi, '/apps//workflows/draft/run') api.add_resource(WorkflowTaskStopApi, '/apps//workflows/tasks//stop') api.add_resource(DraftWorkflowNodeRunApi, '/apps//workflows/draft/nodes//run') api.add_resource(PublishedWorkflowApi, '/apps//workflows/published') api.add_resource(DefaultBlockConfigsApi, '/apps//workflows/default-workflow-block-configs') -api.add_resource(DefaultBlockConfigApi, '/apps//workflows/default-workflow-block-configs/:block_type') +api.add_resource(DefaultBlockConfigApi, '/apps//workflows/default-workflow-block-configs' + '/') api.add_resource(ConvertToWorkflowApi, '/apps//convert-to-workflow') diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index ca2f400547..918fd4566e 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -16,18 +16,19 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError -from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.account import Account from models.model import App, Conversation, EndUser, Message +from models.workflow import Workflow logger = logging.getLogger(__name__) class AdvancedChatAppGenerator(MessageBasedAppGenerator): def generate(self, app_model: App, + workflow: Workflow, user: Union[Account, EndUser], - args: Any, + args: dict, invoke_from: InvokeFrom, stream: bool = True) \ -> Union[dict, Generator]: @@ -35,6 +36,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): Generate App response. 
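Both draft-run endpoints hand a generator to `stream_with_context` and reply with `text/event-stream`, so a client should read the body incrementally rather than waiting for a complete JSON document. A hypothetical consumer; the host, app id, and token are placeholders:

```python
import requests

resp = requests.post(
    'http://localhost:5001/console/api/apps/your-app-uuid'
    '/advanced-chat/workflows/draft/run',         # placeholder host and app id
    headers={'Authorization': 'Bearer <token>'},  # placeholder auth
    json={'inputs': {}, 'query': 'hello'},
    stream=True,                                  # do not buffer the body
)
for line in resp.iter_lines(decode_unicode=True):
    if line:
        print(line)  # each event arrives as its own line
```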
:param app_model: App + :param workflow: Workflow :param user: account or end user :param args: request args :param invoke_from: invoke from source @@ -59,16 +61,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): if args.get('conversation_id'): conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) - # get workflow - workflow_engine_manager = WorkflowEngineManager() - if invoke_from == InvokeFrom.DEBUGGER: - workflow = workflow_engine_manager.get_draft_workflow(app_model=app_model) - else: - workflow = workflow_engine_manager.get_published_workflow(app_model=app_model) - - if not workflow: - raise ValueError('Workflow not initialized') - # parse files files = args['files'] if 'files' in args and args['files'] else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 0d701ae224..f853f88af4 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -1,15 +1,20 @@ import logging +import time from typing import cast from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import ( - AdvancedChatAppGenerateEntity, + AdvancedChatAppGenerateEntity, InvokeFrom, ) +from core.app.entities.queue_entities import QueueStopEvent from core.moderation.base import ModerationException +from core.workflow.entities.node_entities import SystemVariable +from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db -from models.model import App, Conversation, Message +from models.account import Account +from models.model import App, Conversation, Message, EndUser logger = logging.getLogger(__name__) @@ -38,66 +43,151 @@ class AdvancedChatAppRunner(AppRunner): if not app_record: raise ValueError("App not found") + workflow = WorkflowEngineManager().get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) + if not workflow: + raise ValueError("Workflow not initialized") + inputs = application_generate_entity.inputs query = application_generate_entity.query files = application_generate_entity.files # moderation + if self.handle_input_moderation( + queue_manager=queue_manager, + app_record=app_record, + app_generate_entity=application_generate_entity, + inputs=inputs, + query=query + ): + return + + # annotation reply + if self.handle_annotation_reply( + app_record=app_record, + message=message, + query=query, + queue_manager=queue_manager, + app_generate_entity=application_generate_entity + ): + return + + # fetch user + if application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE]: + user = db.session.query(Account).filter(Account.id == application_generate_entity.user_id).first() + else: + user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() + + # RUN WORKFLOW + workflow_engine_manager = WorkflowEngineManager() + result_generator = workflow_engine_manager.run_workflow( + app_model=app_record, + workflow=workflow, + user=user, + user_inputs=inputs, + system_inputs={ + SystemVariable.QUERY: query, + SystemVariable.FILES: files, + SystemVariable.CONVERSATION: conversation.id, + } + ) + + for result in result_generator: + # todo handle workflow 
and node event + pass + + + def handle_input_moderation(self, queue_manager: AppQueueManager, + app_record: App, + app_generate_entity: AdvancedChatAppGenerateEntity, + inputs: dict, + query: str) -> bool: + """ + Handle input moderation + :param queue_manager: application queue manager + :param app_record: app record + :param app_generate_entity: application generate entity + :param inputs: inputs + :param query: query + :return: + """ try: # process sensitive_word_avoidance _, inputs, query = self.moderation_for_inputs( app_id=app_record.id, - tenant_id=app_config.tenant_id, - app_generate_entity=application_generate_entity, + tenant_id=app_generate_entity.app_config.tenant_id, + app_generate_entity=app_generate_entity, inputs=inputs, query=query, ) except ModerationException as e: - # TODO - self.direct_output( + self._stream_output( queue_manager=queue_manager, - app_generate_entity=application_generate_entity, - prompt_messages=prompt_messages, text=str(e), - stream=application_generate_entity.stream + stream=app_generate_entity.stream, + stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION ) - return + return True - if query: - # annotation reply - annotation_reply = self.query_app_annotations_to_reply( - app_record=app_record, - message=message, - query=query, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from - ) + return False - if annotation_reply: - queue_manager.publish_annotation_reply( - message_annotation_id=annotation_reply.id, - pub_from=PublishFrom.APPLICATION_MANAGER - ) - - # TODO - self.direct_output( - queue_manager=queue_manager, - app_generate_entity=application_generate_entity, - prompt_messages=prompt_messages, - text=annotation_reply.content, - stream=application_generate_entity.stream - ) - return - - # check hosting moderation - # TODO - hosting_moderation_result = self.check_hosting_moderation( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - prompt_messages=prompt_messages + def handle_annotation_reply(self, app_record: App, + message: Message, + query: str, + queue_manager: AppQueueManager, + app_generate_entity: AdvancedChatAppGenerateEntity) -> bool: + """ + Handle annotation reply + :param app_record: app record + :param message: message + :param query: query + :param queue_manager: application queue manager + :param app_generate_entity: application generate entity + """ + # annotation reply + annotation_reply = self.query_app_annotations_to_reply( + app_record=app_record, + message=message, + query=query, + user_id=app_generate_entity.user_id, + invoke_from=app_generate_entity.invoke_from ) - if hosting_moderation_result: - return + if annotation_reply: + queue_manager.publish_annotation_reply( + message_annotation_id=annotation_reply.id, + pub_from=PublishFrom.APPLICATION_MANAGER + ) - # todo RUN WORKFLOW \ No newline at end of file + self._stream_output( + queue_manager=queue_manager, + text=annotation_reply.content, + stream=app_generate_entity.stream, + stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY + ) + return True + + return False + + def _stream_output(self, queue_manager: AppQueueManager, + text: str, + stream: bool, + stopped_by: QueueStopEvent.StopBy) -> None: + """ + Direct output + :param queue_manager: application queue manager + :param text: text + :param stream: stream + :return: + """ + if stream: + index = 0 + for token in text: + queue_manager.publish_text_chunk(token, PublishFrom.APPLICATION_MANAGER) + index += 1 + time.sleep(0.01) + + 
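`handle_input_moderation` and `handle_annotation_reply` share one shape: run a check, and on a hit stream canned text plus a stop event, then report `True` so `run` returns before the workflow starts. The skeleton of that short-circuit, with the queue plumbing replaced by prints:

```python
import time

def stream_output(text: str) -> None:
    # stand-in for _stream_output: one token at a time, then a stop event
    for token in text:
        print(token, end='', flush=True)
        time.sleep(0.01)
    print('\n[stop event published]')

def run(query: str) -> None:
    if query == 'flagged input':     # stand-in for the moderation check
        stream_output('Preset moderation reply.')
        return                       # short-circuit: workflow never runs
    if query == 'known question':    # stand-in for the annotation lookup
        stream_output('Saved annotation answer.')
        return
    print('RUN WORKFLOW')

run('flagged input')
run('anything else')
```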
queue_manager.publish( + QueueStopEvent(stopped_by=stopped_by), + PublishFrom.APPLICATION_MANAGER + ) + queue_manager.stop_listen() diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 25bdd7d9e3..e5c6a8eff9 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -165,6 +165,7 @@ class QueueStopEvent(AppQueueEvent): USER_MANUAL = "user-manual" ANNOTATION_REPLY = "annotation-reply" OUTPUT_MODERATION = "output-moderation" + INPUT_MODERATION = "input-moderation" event = QueueEvent.STOP stopped_by: StopBy diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 80471cc702..18f0f7746c 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -30,3 +30,12 @@ class NodeType(Enum): if node_type.value == value: return node_type raise ValueError(f'invalid node type value {value}') + + +class SystemVariable(Enum): + """ + System Variables. + """ + QUERY = 'query' + FILES = 'files' + CONVERSATION = 'conversation' diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py new file mode 100644 index 0000000000..eefee88c07 --- /dev/null +++ b/api/core/workflow/entities/variable_pool.py @@ -0,0 +1,82 @@ +from enum import Enum +from typing import Optional, Union, Any + +from core.workflow.entities.node_entities import SystemVariable + +VariableValue = Union[str, int, float, dict, list] + + +class ValueType(Enum): + """ + Value Type Enum + """ + STRING = "string" + NUMBER = "number" + OBJECT = "object" + ARRAY = "array" + FILE = "file" + + +class VariablePool: + variables_mapping = {} + + def __init__(self, system_variables: dict[SystemVariable, Any]) -> None: + # system variables + # for example: + # { + # 'query': 'abc', + # 'files': [] + # } + for system_variable, value in system_variables.items(): + self.append_variable('sys', [system_variable.value], value) + + def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None: + """ + Append variable + :param node_id: node id + :param variable_key_list: variable key list, like: ['result', 'text'] + :param value: value + :return: + """ + if node_id not in self.variables_mapping: + self.variables_mapping[node_id] = {} + + variable_key_list_hash = hash(tuple(variable_key_list)) + + self.variables_mapping[node_id][variable_key_list_hash] = value + + def get_variable_value(self, variable_selector: list[str], + target_value_type: Optional[ValueType] = None) -> Optional[VariableValue]: + """ + Get variable + :param variable_selector: include node_id and variables + :param target_value_type: target value type + :return: + """ + if len(variable_selector) < 2: + raise ValueError('Invalid value selector') + + node_id = variable_selector[0] + if node_id not in self.variables_mapping: + return None + + # fetch variable keys, pop node_id + variable_key_list = variable_selector[1:] + + variable_key_list_hash = hash(tuple(variable_key_list)) + + value = self.variables_mapping[node_id].get(variable_key_list_hash) + + if target_value_type: + if target_value_type == ValueType.STRING: + return str(value) + elif target_value_type == ValueType.NUMBER: + return int(value) + elif target_value_type == ValueType.OBJECT: + if not isinstance(value, dict): + raise ValueError('Invalid value type: object') + elif target_value_type == ValueType.ARRAY: + if not isinstance(value, list): + raise ValueError('Invalid 
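`VariablePool` addresses values by node id plus a hashed key path; a selector list is always `[node_id, *keys]`. A runnable miniature of the same contract (one caution about the original: `variables_mapping` is declared at class level rather than inside `__init__`, so as written, separate pool instances would share state):

```python
class MiniVariablePool:
    def __init__(self) -> None:
        self.variables_mapping = {}  # instance-level in this sketch

    def append_variable(self, node_id: str, variable_key_list: list, value) -> None:
        key = hash(tuple(variable_key_list))
        self.variables_mapping.setdefault(node_id, {})[key] = value

    def get_variable_value(self, variable_selector: list):
        node_id, *variable_key_list = variable_selector
        key = hash(tuple(variable_key_list))
        return self.variables_mapping.get(node_id, {}).get(key)

pool = MiniVariablePool()
pool.append_variable('llm_1', ['usage', 'total_tokens'], 42)
assert pool.get_variable_value(['llm_1', 'usage', 'total_tokens']) == 42
assert pool.get_variable_value(['llm_1', 'missing']) is None
```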
value type: array') + + return value diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 665338af08..a2751b346f 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -1,7 +1,44 @@ +from abc import abstractmethod from typing import Optional +from core.workflow.entities.node_entities import NodeType +from core.workflow.entities.variable_pool import VariablePool + class BaseNode: + _node_type: NodeType + + def __int__(self, node_config: dict) -> None: + self._node_config = node_config + + @abstractmethod + def run(self, variable_pool: Optional[VariablePool] = None, + run_args: Optional[dict] = None) -> dict: + """ + Run node + :param variable_pool: variable pool + :param run_args: run args + :return: + """ + if variable_pool is None and run_args is None: + raise ValueError("At least one of `variable_pool` or `run_args` must be provided.") + + return self._run( + variable_pool=variable_pool, + run_args=run_args + ) + + @abstractmethod + def _run(self, variable_pool: Optional[VariablePool] = None, + run_args: Optional[dict] = None) -> dict: + """ + Run node + :param variable_pool: variable pool + :param run_args: run args + :return: + """ + raise NotImplementedError + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 73e92d5e89..5914bfc152 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,5 +1,6 @@ -from typing import Optional +from typing import Optional, Union, Generator +from core.memory.token_buffer_memory import TokenBufferMemory from core.workflow.entities.node_entities import NodeType from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode @@ -14,7 +15,8 @@ from core.workflow.nodes.template_transform.template_transform_node import Templ from core.workflow.nodes.tool.tool_node import ToolNode from core.workflow.nodes.variable_assigner.variable_assigner_node import VariableAssignerNode from extensions.ext_database import db -from models.model import App +from models.account import Account +from models.model import App, EndUser, Conversation from models.workflow import Workflow node_classes = { @@ -56,13 +58,20 @@ class WorkflowEngineManager: return None # fetch published workflow by workflow_id + return self.get_workflow(app_model, app_model.workflow_id) + + def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: + """ + Get workflow + """ + # fetch workflow by workflow_id workflow = db.session.query(Workflow).filter( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, - Workflow.id == app_model.workflow_id + Workflow.id == workflow_id ).first() - # return published workflow + # return workflow return workflow def get_default_configs(self) -> list[dict]: @@ -96,3 +105,20 @@ class WorkflowEngineManager: return None return default_config + + def run_workflow(self, app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + user_inputs: dict, + system_inputs: Optional[dict] = None) -> Generator: + """ + Run workflow + :param app_model: App instance + :param workflow: Workflow instance + :param user: account or end user + :param user_inputs: user variables inputs + :param system_inputs: system inputs, like: query, files + :return: + """ + # TODO + pass diff 
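Concrete nodes implement `_run` and inherit `run`, which insists on at least one of `variable_pool` or `run_args` (single-node debugging passes `run_args`; full graph execution passes the pool). A hypothetical echo node against the `BaseNode` contract above; its output shape is an assumption, since the diff leaves `_run`'s return format open:

```python
from typing import Optional

from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode

class EchoNode(BaseNode):
    _node_type = NodeType.TEMPLATE_TRANSFORM  # borrowed type, sketch only

    def _run(self, variable_pool: Optional[VariablePool] = None,
             run_args: Optional[dict] = None) -> dict:
        if run_args is not None:  # single-node debug path
            return {'text': run_args.get('text', '')}
        # full-run path: read an upstream value out of the pool
        return {'text': variable_pool.get_variable_value(['start', 'text']) or ''}

print(EchoNode().run(run_args={'text': 'hello'}))  # {'text': 'hello'}
```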
--git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index bcb2c318c6..9919a440e8 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -5,8 +5,8 @@ from libs.helper import TimestampField workflow_fields = { 'id': fields.String, - 'graph': fields.Nested(simple_account_fields, attribute='graph_dict'), - 'features': fields.Nested(simple_account_fields, attribute='features_dict'), + 'graph': fields.Raw(attribute='graph_dict'), + 'features': fields.Raw(attribute='features_dict'), 'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'), 'created_at': TimestampField, 'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 37751bc70f..85c9c2d2b2 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -22,10 +22,10 @@ workflow_run_for_list_fields = { "id": fields.String, "sequence_number": fields.Integer, "version": fields.String, - "graph": fields.String, - "inputs": fields.String, + "graph": fields.Raw(attribute='graph_dict'), + "inputs": fields.Raw(attribute='inputs_dict'), "status": fields.String, - "outputs": fields.String, + "outputs": fields.Raw(attribute='outputs_dict'), "error": fields.String, "elapsed_time": fields.Float, "total_tokens": fields.Integer, @@ -49,10 +49,10 @@ workflow_run_detail_fields = { "id": fields.String, "sequence_number": fields.Integer, "version": fields.String, - "graph": fields.String, - "inputs": fields.String, + "graph": fields.Raw(attribute='graph_dict'), + "inputs": fields.Raw(attribute='inputs_dict'), "status": fields.String, - "outputs": fields.String, + "outputs": fields.Raw(attribute='outputs_dict'), "error": fields.String, "elapsed_time": fields.Float, "total_tokens": fields.Integer, @@ -73,13 +73,13 @@ workflow_run_node_execution_fields = { "node_id": fields.String, "node_type": fields.String, "title": fields.String, - "inputs": fields.String, - "process_data": fields.String, - "outputs": fields.String, + "inputs": fields.Raw(attribute='inputs_dict'), + "process_data": fields.Raw(attribute='process_data_dict'), + "outputs": fields.Raw(attribute='outputs_dict'), "status": fields.String, "error": fields.String, "elapsed_time": fields.Float, - "execution_metadata": fields.String, + "execution_metadata": fields.Raw(attribute='execution_metadata_dict'), "created_at": TimestampField, "created_by_role": fields.String, "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), diff --git a/api/models/workflow.py b/api/models/workflow.py index 2540d33402..32ff26196c 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -272,6 +272,14 @@ class WorkflowRun(db.Model): return EndUser.query.get(self.created_by) \ if created_by_role == CreatedByRole.END_USER else None + @property + def graph_dict(self): + return self.graph if not self.graph else json.loads(self.graph) + + @property + def inputs_dict(self): + return self.inputs if not self.inputs else json.loads(self.inputs) + @property def outputs_dict(self): return self.outputs if not self.outputs else json.loads(self.outputs) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 0be0783ae0..37f5c16bec 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1,14 +1,16 @@ import json from datetime import datetime -from typing import Optional +from typing 
import Optional, Union, Any, Generator from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager +from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import NodeType from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.account import Account -from models.model import App, AppMode +from models.model import App, AppMode, EndUser from models.workflow import Workflow, WorkflowType from services.workflow.workflow_converter import WorkflowConverter @@ -142,6 +144,39 @@ class WorkflowService: workflow_engine_manager = WorkflowEngineManager() return workflow_engine_manager.get_default_config(node_type, filters) + def run_advanced_chat_draft_workflow(self, app_model: App, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom) -> Union[dict, Generator]: + """ + Run advanced chatbot draft workflow + """ + # fetch draft workflow by app_model + draft_workflow = self.get_draft_workflow(app_model=app_model) + + if not draft_workflow: + raise ValueError('Workflow not initialized') + + # run draft workflow + app_generator = AdvancedChatAppGenerator() + response = app_generator.generate( + app_model=app_model, + workflow=draft_workflow, + user=user, + args=args, + invoke_from=invoke_from, + stream=True + ) + + return response + + def run_draft_workflow(self, app_model: App, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom) -> Union[dict, Generator]: + # TODO + pass + def convert_to_workflow(self, app_model: App, account: Account) -> App: """ Basic mode of chatbot app(expert mode) to workflow From c8a1f923f53f720e84a941456284ee3f3de167c7 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 17:23:35 +0800 Subject: [PATCH 107/200] lint fix --- api/controllers/console/app/workflow.py | 7 +++---- api/core/app/apps/advanced_chat/app_generator.py | 2 +- api/core/app/apps/advanced_chat/app_runner.py | 5 +++-- api/core/workflow/workflow_engine_manager.py | 6 +++--- api/services/workflow_service.py | 3 ++- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 9ee6ca9dbd..6e77f50e65 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -1,14 +1,14 @@ import json import logging -from typing import Generator +from collections.abc import Generator from flask import Response, stream_with_context from flask_restful import Resource, marshal_with, reqparse -from werkzeug.exceptions import NotFound, InternalServerError +from werkzeug.exceptions import InternalServerError, NotFound import services from controllers.console import api -from controllers.console.app.error import DraftWorkflowNotExist, ConversationCompletedError +from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required @@ -19,7 +19,6 @@ from libs.login import current_user, login_required from models.model import App, AppMode from services.workflow_service import WorkflowService - logger = logging.getLogger(__name__) diff --git 
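The new model properties share one guard, `self.graph if not self.graph else json.loads(self.graph)`: a falsy raw column (None or empty string) passes through untouched, anything else is decoded, and the marshal layer then points `fields.Raw` at the `_dict` attribute so API responses carry real objects instead of JSON strings. The pattern in isolation:

```python
import json
from typing import Optional

class FakeWorkflowRun:
    def __init__(self, graph: Optional[str]) -> None:
        self.graph = graph  # the raw text column

    @property
    def graph_dict(self):
        # falsy (None / '') passes through; valid JSON gets decoded
        return self.graph if not self.graph else json.loads(self.graph)

assert FakeWorkflowRun(None).graph_dict is None
assert FakeWorkflowRun('{"nodes": []}').graph_dict == {'nodes': []}
```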
a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 918fd4566e..937f95679a 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -2,7 +2,7 @@ import logging import threading import uuid from collections.abc import Generator -from typing import Any, Union +from typing import Union from flask import Flask, current_app from pydantic import ValidationError diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index f853f88af4..02d22072df 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -6,7 +6,8 @@ from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import ( - AdvancedChatAppGenerateEntity, InvokeFrom, + AdvancedChatAppGenerateEntity, + InvokeFrom, ) from core.app.entities.queue_entities import QueueStopEvent from core.moderation.base import ModerationException @@ -14,7 +15,7 @@ from core.workflow.entities.node_entities import SystemVariable from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.account import Account -from models.model import App, Conversation, Message, EndUser +from models.model import App, Conversation, EndUser, Message logger = logging.getLogger(__name__) diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 5914bfc152..8a23048705 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,6 +1,6 @@ -from typing import Optional, Union, Generator +from collections.abc import Generator +from typing import Optional, Union -from core.memory.token_buffer_memory import TokenBufferMemory from core.workflow.entities.node_entities import NodeType from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode @@ -16,7 +16,7 @@ from core.workflow.nodes.tool.tool_node import ToolNode from core.workflow.nodes.variable_assigner.variable_assigner_node import VariableAssignerNode from extensions.ext_database import db from models.account import Account -from models.model import App, EndUser, Conversation +from models.model import App, EndUser from models.workflow import Workflow node_classes = { diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 37f5c16bec..2c1b6eb819 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1,6 +1,7 @@ import json +from collections.abc import Generator from datetime import datetime -from typing import Optional, Union, Any, Generator +from typing import Optional, Union from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator From 1a86e79d4a6b32ed818f3278e0377dab17060aba Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 17:23:40 +0800 Subject: [PATCH 108/200] lint fix --- api/core/workflow/entities/variable_pool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index eefee88c07..e84044dede 
100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Optional, Union, Any +from typing import Any, Optional, Union from core.workflow.entities.node_entities import SystemVariable From 75f1355d4c742399f247a7dd0737512b6f1741db Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 23:34:23 +0800 Subject: [PATCH 109/200] add few workflow run codes --- api/commands.py | 2 +- api/core/app/app_config/entities.py | 1 + api/core/app/apps/advanced_chat/app_runner.py | 7 +- api/core/callback_handler/__init__.py | 0 .../std_out_callback_handler.py | 157 ------------------ .../workflow_event_trigger_callback.py | 45 +++++ api/core/workflow/callbacks/__init__.py | 0 api/core/workflow/callbacks/base_callback.py | 33 ++++ .../entities/base_node_data_entities.py | 7 + api/core/workflow/nodes/base_node.py | 43 ++--- api/core/workflow/nodes/start/entities.py | 27 +++ api/core/workflow/nodes/start/start_node.py | 19 ++- api/core/workflow/workflow_engine_manager.py | 96 ++++++++++- 13 files changed, 254 insertions(+), 183 deletions(-) create mode 100644 api/core/callback_handler/__init__.py delete mode 100644 api/core/callback_handler/std_out_callback_handler.py create mode 100644 api/core/callback_handler/workflow_event_trigger_callback.py create mode 100644 api/core/workflow/callbacks/__init__.py create mode 100644 api/core/workflow/callbacks/base_callback.py create mode 100644 api/core/workflow/entities/base_node_data_entities.py create mode 100644 api/core/workflow/nodes/start/entities.py diff --git a/api/commands.py b/api/commands.py index 73325620ee..376a394d1e 100644 --- a/api/commands.py +++ b/api/commands.py @@ -15,7 +15,7 @@ from libs.rsa import generate_key_pair from models.account import Tenant from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment from models.dataset import Document as DatasetDocument -from models.model import Account, App, AppMode, AppModelConfig, AppAnnotationSetting, Conversation, MessageAnnotation +from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation from models.provider import Provider, ProviderModel diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index e155dc1c4d..6a521dfcc5 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -112,6 +112,7 @@ class VariableEntity(BaseModel): max_length: Optional[int] = None options: Optional[list[str]] = None default: Optional[str] = None + hint: Optional[str] = None class ExternalDataVariableEntity(BaseModel): diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 02d22072df..920adcfb79 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -10,12 +10,14 @@ from core.app.entities.app_invoke_entities import ( InvokeFrom, ) from core.app.entities.queue_entities import QueueStopEvent +from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.moderation.base import ModerationException from core.workflow.entities.node_entities import SystemVariable from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.account import Account from models.model import App, Conversation, EndUser, Message +from models.workflow import WorkflowRunTriggeredFrom logger = 
logging.getLogger(__name__) @@ -83,13 +85,16 @@ class AdvancedChatAppRunner(AppRunner): result_generator = workflow_engine_manager.run_workflow( app_model=app_record, workflow=workflow, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING + if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN, user=user, user_inputs=inputs, system_inputs={ SystemVariable.QUERY: query, SystemVariable.FILES: files, SystemVariable.CONVERSATION: conversation.id, - } + }, + callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)] ) for result in result_generator: diff --git a/api/core/callback_handler/__init__.py b/api/core/callback_handler/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/callback_handler/std_out_callback_handler.py b/api/core/callback_handler/std_out_callback_handler.py deleted file mode 100644 index 1f95471afb..0000000000 --- a/api/core/callback_handler/std_out_callback_handler.py +++ /dev/null @@ -1,157 +0,0 @@ -import os -import sys -from typing import Any, Optional, Union - -from langchain.callbacks.base import BaseCallbackHandler -from langchain.input import print_text -from langchain.schema import AgentAction, AgentFinish, BaseMessage, LLMResult - - -class DifyStdOutCallbackHandler(BaseCallbackHandler): - """Callback Handler that prints to std out.""" - - def __init__(self, color: Optional[str] = None) -> None: - """Initialize callback handler.""" - self.color = color - - def on_chat_model_start( - self, - serialized: dict[str, Any], - messages: list[list[BaseMessage]], - **kwargs: Any - ) -> Any: - print_text("\n[on_chat_model_start]\n", color='blue') - for sub_messages in messages: - for sub_message in sub_messages: - print_text(str(sub_message) + "\n", color='blue') - - def on_llm_start( - self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any - ) -> None: - """Print out the prompts.""" - print_text("\n[on_llm_start]\n", color='blue') - print_text(prompts[0] + "\n", color='blue') - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Do nothing.""" - print_text("\n[on_llm_end]\nOutput: " + str(response.generations[0][0].text) + "\nllm_output: " + str( - response.llm_output) + "\n", color='blue') - - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Do nothing.""" - pass - - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - print_text("\n[on_llm_error]\nError: " + str(error) + "\n", color='blue') - - def on_chain_start( - self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any - ) -> None: - """Print out that we are entering a chain.""" - chain_type = serialized['id'][-1] - print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink') - - def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None: - """Print out that we finished a chain.""" - print_text("\n[on_chain_end]\nOutputs: " + str(outputs) + "\n", color='pink') - - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - print_text("\n[on_chain_error]\nError: " + str(error) + "\n", color='pink') - - def on_tool_start( - self, - serialized: dict[str, Any], - input_str: str, - **kwargs: Any, - ) -> None: - """Do nothing.""" - print_text("\n[on_tool_start] " + str(serialized), color='yellow') - - def on_agent_action( - self, action: AgentAction, color: Optional[str] = None, **kwargs: 
Any - ) -> Any: - """Run on agent action.""" - tool = action.tool - tool_input = action.tool_input - try: - action_name_position = action.log.index("\nAction:") + 1 if action.log else -1 - thought = action.log[:action_name_position].strip() if action.log else '' - except ValueError: - thought = '' - - log = f"Thought: {thought}\nTool: {tool}\nTool Input: {tool_input}" - print_text("\n[on_agent_action]\n" + log + "\n", color='green') - - def on_tool_end( - self, - output: str, - color: Optional[str] = None, - observation_prefix: Optional[str] = None, - llm_prefix: Optional[str] = None, - **kwargs: Any, - ) -> None: - """If not the final action, print out observation.""" - print_text("\n[on_tool_end]\n", color='yellow') - if observation_prefix: - print_text(f"\n{observation_prefix}") - print_text(output, color='yellow') - if llm_prefix: - print_text(f"\n{llm_prefix}") - print_text("\n") - - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='yellow') - - def on_text( - self, - text: str, - color: Optional[str] = None, - end: str = "", - **kwargs: Optional[str], - ) -> None: - """Run when agent ends.""" - print_text("\n[on_text] " + text + "\n", color=color if color else self.color, end=end) - - def on_agent_finish( - self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any - ) -> None: - """Run on agent end.""" - print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n") - - @property - def ignore_llm(self) -> bool: - """Whether to ignore LLM callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' - - @property - def ignore_chain(self) -> bool: - """Whether to ignore chain callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' - - @property - def ignore_agent(self) -> bool: - """Whether to ignore agent callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' - - @property - def ignore_chat_model(self) -> bool: - """Whether to ignore chat model callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' - - -class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler): - """Callback handler for streaming. Only works with LLMs that support streaming.""" - - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Run on new LLM token. 
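
The langchain-coupled stdout handler above is removed outright rather than ported: debug visibility for the new engine is meant to flow through the `BaseWorkflowCallback` interface added below. For orientation, a console-logging implementation of that interface could look like this sketch (not part of the patch; it assumes the modules introduced in this commit are importable):

    from core.workflow.callbacks.base_callback import BaseWorkflowCallback
    from models.workflow import WorkflowNodeExecution, WorkflowRun


    class ConsoleWorkflowCallback(BaseWorkflowCallback):
        # Prints workflow lifecycle events to stdout, for local debugging.

        def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
            print(f"[workflow started] run={workflow_run.id}")

        def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None:
            print(f"[workflow finished] run={workflow_run.id} status={workflow_run.status}")

        def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None:
            print(f"[node started] node={workflow_node_execution.node_id}")

        def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None:
            print(f"[node finished] node={workflow_node_execution.node_id} status={workflow_node_execution.status}")
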
Only available when streaming is enabled.""" - sys.stdout.write(token) - sys.stdout.flush() diff --git a/api/core/callback_handler/workflow_event_trigger_callback.py b/api/core/callback_handler/workflow_event_trigger_callback.py new file mode 100644 index 0000000000..2f81f27426 --- /dev/null +++ b/api/core/callback_handler/workflow_event_trigger_callback.py @@ -0,0 +1,45 @@ +from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.workflow.callbacks.base_callback import BaseWorkflowCallback +from models.workflow import WorkflowRun, WorkflowNodeExecution + + +class WorkflowEventTriggerCallback(BaseWorkflowCallback): + + def __init__(self, queue_manager: AppQueueManager): + self._queue_manager = queue_manager + + def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: + """ + Workflow run started + """ + self._queue_manager.publish_workflow_started( + workflow_run_id=workflow_run.id, + pub_from=PublishFrom.TASK_PIPELINE + ) + + def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None: + """ + Workflow run finished + """ + self._queue_manager.publish_workflow_finished( + workflow_run_id=workflow_run.id, + pub_from=PublishFrom.TASK_PIPELINE + ) + + def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None: + """ + Workflow node execute started + """ + self._queue_manager.publish_node_started( + workflow_node_execution_id=workflow_node_execution.id, + pub_from=PublishFrom.TASK_PIPELINE + ) + + def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None: + """ + Workflow node execute finished + """ + self._queue_manager.publish_node_finished( + workflow_node_execution_id=workflow_node_execution.id, + pub_from=PublishFrom.TASK_PIPELINE + ) diff --git a/api/core/workflow/callbacks/__init__.py b/api/core/workflow/callbacks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/callbacks/base_callback.py b/api/core/workflow/callbacks/base_callback.py new file mode 100644 index 0000000000..a564af498c --- /dev/null +++ b/api/core/workflow/callbacks/base_callback.py @@ -0,0 +1,33 @@ +from abc import abstractmethod + +from models.workflow import WorkflowRun, WorkflowNodeExecution + + +class BaseWorkflowCallback: + @abstractmethod + def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: + """ + Workflow run started + """ + raise NotImplementedError + + @abstractmethod + def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None: + """ + Workflow run finished + """ + raise NotImplementedError + + @abstractmethod + def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None: + """ + Workflow node execute started + """ + raise NotImplementedError + + @abstractmethod + def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None: + """ + Workflow node execute finished + """ + raise NotImplementedError diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/entities/base_node_data_entities.py new file mode 100644 index 0000000000..32b93ea094 --- /dev/null +++ b/api/core/workflow/entities/base_node_data_entities.py @@ -0,0 +1,7 @@ +from abc import ABC + +from pydantic import BaseModel + + +class BaseNodeData(ABC, BaseModel): + pass diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index a2751b346f..a95a232ae6 100644 --- a/api/core/workflow/nodes/base_node.py +++ 
b/api/core/workflow/nodes/base_node.py
@@ -1,32 +1,21 @@
 from abc import abstractmethod
-from typing import Optional
+from typing import Optional, Type
 
+from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeType
 from core.workflow.entities.variable_pool import VariablePool
 
 
 class BaseNode:
     _node_type: NodeType
+    _node_data_cls: Type[BaseNodeData]
 
-    def __int__(self, node_config: dict) -> None:
-        self._node_config = node_config
+    def __init__(self, config: dict) -> None:
+        self._node_id = config.get("id")
+        if not self._node_id:
+            raise ValueError("Node ID is required.")
 
-    @abstractmethod
-    def run(self, variable_pool: Optional[VariablePool] = None,
-            run_args: Optional[dict] = None) -> dict:
-        """
-        Run node
-        :param variable_pool: variable pool
-        :param run_args: run args
-        :return:
-        """
-        if variable_pool is None and run_args is None:
-            raise ValueError("At least one of `variable_pool` or `run_args` must be provided.")
-
-        return self._run(
-            variable_pool=variable_pool,
-            run_args=run_args
-        )
+        self._node_data = self._node_data_cls(**config.get("data", {}))
 
     @abstractmethod
     def _run(self, variable_pool: Optional[VariablePool] = None,
@@ -39,6 +28,22 @@ class BaseNode:
         """
         raise NotImplementedError
 
+    def run(self, variable_pool: Optional[VariablePool] = None,
+            run_args: Optional[dict] = None) -> dict:
+        """
+        Run node entry
+        :param variable_pool: variable pool
+        :param run_args: run args
+        :return:
+        """
+        if variable_pool is None and run_args is None:
+            raise ValueError("At least one of `variable_pool` or `run_args` must be provided.")
+
+        return self._run(
+            variable_pool=variable_pool,
+            run_args=run_args
+        )
+
     @classmethod
     def get_default_config(cls, filters: Optional[dict] = None) -> dict:
         """
diff --git a/api/core/workflow/nodes/start/entities.py b/api/core/workflow/nodes/start/entities.py
new file mode 100644
index 0000000000..25b27cf192
--- /dev/null
+++ b/api/core/workflow/nodes/start/entities.py
@@ -0,0 +1,27 @@
+from typing import Optional
+
+from core.app.app_config.entities import VariableEntity
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.node_entities import NodeType
+
+
+class StartNodeData(BaseNodeData):
+    """
+    - title (string) node title
+    - desc (string) optional node description
+    - type (string) node type, fixed to start
+    - variables (array[object]) list of form variables
+        - type (string) form variable type: text-input, paragraph, select, number, files (custom files are not supported yet)
+        - label (string) display label of the control
+        - variable (string) variable key
+        - max_length (int) maximum length, applies to text-input and paragraph
+        - default (string) optional default value
+        - required (bool) optional, whether the variable is required, defaults to false
+        - hint (string) optional hint text
+        - options (array[string]) option values (select only)
+    """
+    type: str = NodeType.START.value
+
+    title: str
+    desc: Optional[str] = None
+    variables: list[VariableEntity] = []
diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py
index 8cce655728..014a146c93 100644
--- a/api/core/workflow/nodes/start/start_node.py
+++ b/api/core/workflow/nodes/start/start_node.py
@@ -1,5 +1,22 @@
+from typing import Type, Optional
+
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.start.entities import StartNodeData
 
 
 class StartNode(BaseNode):
-    pass
+    _node_type = NodeType.START
+    _node_data_cls = StartNodeData
+
+    def _run(self, variable_pool:
Optional[VariablePool] = None, + run_args: Optional[dict] = None) -> dict: + """ + Run node + :param variable_pool: variable pool + :param run_args: run args + :return: + """ + pass + diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 8a23048705..afa4dbb321 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,6 +1,8 @@ +import json from collections.abc import Generator from typing import Optional, Union +from core.workflow.callbacks.base_callback import BaseWorkflowCallback from core.workflow.entities.node_entities import NodeType from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode @@ -17,7 +19,7 @@ from core.workflow.nodes.variable_assigner.variable_assigner_node import Variabl from extensions.ext_database import db from models.account import Account from models.model import App, EndUser -from models.workflow import Workflow +from models.workflow import Workflow, WorkflowRunTriggeredFrom, WorkflowRun, WorkflowRunStatus, CreatedByRole node_classes = { NodeType.START: StartNode, @@ -108,17 +110,103 @@ class WorkflowEngineManager: def run_workflow(self, app_model: App, workflow: Workflow, + triggered_from: WorkflowRunTriggeredFrom, user: Union[Account, EndUser], user_inputs: dict, - system_inputs: Optional[dict] = None) -> Generator: + system_inputs: Optional[dict] = None, + callbacks: list[BaseWorkflowCallback] = None) -> Generator: """ Run workflow :param app_model: App instance :param workflow: Workflow instance + :param triggered_from: triggered from + :param user: account or end user + :param user_inputs: user variables inputs + :param system_inputs: system inputs, like: query, files + :param callbacks: workflow callbacks + :return: + """ + # fetch workflow graph + graph = workflow.graph_dict + if not graph: + raise ValueError('workflow graph not found') + + # init workflow run + workflow_run = self._init_workflow_run( + workflow=workflow, + triggered_from=triggered_from, + user=user, + user_inputs=user_inputs, + system_inputs=system_inputs + ) + + if callbacks: + for callback in callbacks: + callback.on_workflow_run_started(workflow_run) + + pass + + def _init_workflow_run(self, workflow: Workflow, + triggered_from: WorkflowRunTriggeredFrom, + user: Union[Account, EndUser], + user_inputs: dict, + system_inputs: Optional[dict] = None) -> WorkflowRun: + """ + Init workflow run + :param workflow: Workflow instance + :param triggered_from: triggered from :param user: account or end user :param user_inputs: user variables inputs :param system_inputs: system inputs, like: query, files :return: """ - # TODO - pass + try: + db.session.begin() + + max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ + .filter(WorkflowRun.tenant_id == workflow.tenant_id) \ + .filter(WorkflowRun.app_id == workflow.app_id) \ + .for_update() \ + .scalar() or 0 + new_sequence_number = max_sequence + 1 + + # init workflow run + workflow_run = WorkflowRun( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + sequence_number=new_sequence_number, + workflow_id=workflow.id, + type=workflow.type, + triggered_from=triggered_from.value, + version=workflow.version, + graph=workflow.graph, + inputs=json.dumps({**user_inputs, **system_inputs}), + status=WorkflowRunStatus.RUNNING.value, + created_by_role=(CreatedByRole.ACCOUNT.value + if isinstance(user, Account) else 
CreatedByRole.END_USER.value), + created_by_id=user.id + ) + + db.session.add(workflow_run) + db.session.commit() + except: + db.session.rollback() + raise + + return workflow_run + + def _get_entry_node(self, graph: dict) -> Optional[StartNode]: + """ + Get entry node + :param graph: workflow graph + :return: + """ + nodes = graph.get('nodes') + if not nodes: + return None + + for node_config in nodes.items(): + if node_config.get('type') == NodeType.START.value: + return StartNode(config=node_config) + + return None From bc4edbfc2bb5526a062248589c9c1f3aee623fe1 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 4 Mar 2024 23:34:28 +0800 Subject: [PATCH 110/200] lint fix --- api/core/callback_handler/workflow_event_trigger_callback.py | 2 +- api/core/workflow/callbacks/base_callback.py | 2 +- api/core/workflow/nodes/base_node.py | 4 ++-- api/core/workflow/nodes/start/start_node.py | 2 +- api/core/workflow/workflow_engine_manager.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/api/core/callback_handler/workflow_event_trigger_callback.py b/api/core/callback_handler/workflow_event_trigger_callback.py index 2f81f27426..e1d2413534 100644 --- a/api/core/callback_handler/workflow_event_trigger_callback.py +++ b/api/core/callback_handler/workflow_event_trigger_callback.py @@ -1,6 +1,6 @@ from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.workflow.callbacks.base_callback import BaseWorkflowCallback -from models.workflow import WorkflowRun, WorkflowNodeExecution +from models.workflow import WorkflowNodeExecution, WorkflowRun class WorkflowEventTriggerCallback(BaseWorkflowCallback): diff --git a/api/core/workflow/callbacks/base_callback.py b/api/core/workflow/callbacks/base_callback.py index a564af498c..76fe4d96d5 100644 --- a/api/core/workflow/callbacks/base_callback.py +++ b/api/core/workflow/callbacks/base_callback.py @@ -1,6 +1,6 @@ from abc import abstractmethod -from models.workflow import WorkflowRun, WorkflowNodeExecution +from models.workflow import WorkflowNodeExecution, WorkflowRun class BaseWorkflowCallback: diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index a95a232ae6..6f28a3f104 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Optional, Type +from typing import Optional from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeType @@ -8,7 +8,7 @@ from core.workflow.entities.variable_pool import VariablePool class BaseNode: _node_type: NodeType - _node_data_cls: Type[BaseNodeData] + _node_data_cls: type[BaseNodeData] def __init__(self, config: dict) -> None: self._node_id = config.get("id") diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 014a146c93..e218cced3d 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,4 +1,4 @@ -from typing import Type, Optional +from typing import Optional from core.workflow.entities.node_entities import NodeType from core.workflow.entities.variable_pool import VariablePool diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index afa4dbb321..3ad36fe1d2 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -19,7 +19,7 @@ from 
core.workflow.nodes.variable_assigner.variable_assigner_node import Variabl from extensions.ext_database import db from models.account import Account from models.model import App, EndUser -from models.workflow import Workflow, WorkflowRunTriggeredFrom, WorkflowRun, WorkflowRunStatus, CreatedByRole +from models.workflow import CreatedByRole, Workflow, WorkflowRun, WorkflowRunStatus, WorkflowRunTriggeredFrom node_classes = { NodeType.START: StartNode, From a5de7b10f36d4854c70630cf19c956854c1eefef Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 5 Mar 2024 17:35:05 +0800 Subject: [PATCH 111/200] update ruff check --- web/.husky/pre-commit | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/web/.husky/pre-commit b/web/.husky/pre-commit index dfd6ec0209..1f8ae9a8d3 100755 --- a/web/.husky/pre-commit +++ b/web/.husky/pre-commit @@ -24,7 +24,21 @@ done if $api_modified; then echo "Running Ruff linter on api module" - ./dev/reformat + + # python style checks rely on `ruff` in path + if ! command -v ruff &> /dev/null; then + echo "Installing Ruff ..." + pip install ruff + fi + + ruff check ./api + result=$? + + if [ $result -ne 0 ]; then + echo "Please run 'dev/reformat' to fix the fixable linting errors." + fi + + exit $result fi if $web_modified; then From 79a10e97295e2ae92ee819904b5f97b3f7b1092b Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 6 Mar 2024 13:26:14 +0800 Subject: [PATCH 112/200] add updated_at to sync workflow api --- api/controllers/console/app/workflow.py | 7 +- api/core/app/apps/advanced_chat/app_runner.py | 7 +- .../entities/base_node_data_entities.py | 6 +- .../workflow/entities/workflow_entities.py | 16 ++ api/core/workflow/nodes/base_node.py | 24 ++- api/core/workflow/nodes/start/entities.py | 4 - api/core/workflow/nodes/start/start_node.py | 2 +- api/core/workflow/workflow_engine_manager.py | 184 +++++++++++++++++- api/libs/helper.py | 2 +- web/.husky/pre-commit | 12 +- 10 files changed, 233 insertions(+), 31 deletions(-) create mode 100644 api/core/workflow/entities/workflow_entities.py diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 6e77f50e65..4f8df6bcec 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -14,7 +14,7 @@ from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.app.entities.app_invoke_entities import InvokeFrom from fields.workflow_fields import workflow_fields -from libs.helper import uuid_value +from libs.helper import TimestampField, uuid_value from libs.login import current_user, login_required from models.model import App, AppMode from services.workflow_service import WorkflowService @@ -56,7 +56,7 @@ class DraftWorkflowApi(Resource): args = parser.parse_args() workflow_service = WorkflowService() - workflow_service.sync_draft_workflow( + workflow = workflow_service.sync_draft_workflow( app_model=app_model, graph=args.get('graph'), features=args.get('features'), @@ -64,7 +64,8 @@ class DraftWorkflowApi(Resource): ) return { - "result": "success" + "result": "success", + "updated_at": TimestampField().format(workflow.updated_at) } diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 920adcfb79..898091f52c 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -82,7 +82,7 @@ class AdvancedChatAppRunner(AppRunner): # RUN 
WORKFLOW workflow_engine_manager = WorkflowEngineManager() - result_generator = workflow_engine_manager.run_workflow( + workflow_engine_manager.run_workflow( app_model=app_record, workflow=workflow, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING @@ -97,11 +97,6 @@ class AdvancedChatAppRunner(AppRunner): callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)] ) - for result in result_generator: - # todo handle workflow and node event - pass - - def handle_input_moderation(self, queue_manager: AppQueueManager, app_record: App, app_generate_entity: AdvancedChatAppGenerateEntity, diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/entities/base_node_data_entities.py index 32b93ea094..afa6ddff04 100644 --- a/api/core/workflow/entities/base_node_data_entities.py +++ b/api/core/workflow/entities/base_node_data_entities.py @@ -1,7 +1,11 @@ from abc import ABC +from typing import Optional from pydantic import BaseModel class BaseNodeData(ABC, BaseModel): - pass + type: str + + title: str + desc: Optional[str] = None diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py new file mode 100644 index 0000000000..21126caf30 --- /dev/null +++ b/api/core/workflow/entities/workflow_entities.py @@ -0,0 +1,16 @@ +from decimal import Decimal + +from core.workflow.entities.variable_pool import VariablePool +from models.workflow import WorkflowNodeExecution, WorkflowRun + + +class WorkflowRunState: + workflow_run: WorkflowRun + start_at: float + variable_pool: VariablePool + + total_tokens: int = 0 + total_price: Decimal = Decimal(0) + currency: str = "USD" + + workflow_node_executions: list[WorkflowNodeExecution] = [] diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 6f28a3f104..314dfb8f22 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -1,21 +1,25 @@ from abc import abstractmethod from typing import Optional +from core.workflow.callbacks.base_callback import BaseWorkflowCallback from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeType from core.workflow.entities.variable_pool import VariablePool class BaseNode: - _node_type: NodeType _node_data_cls: type[BaseNodeData] + _node_type: NodeType + + node_id: str + node_data: BaseNodeData def __init__(self, config: dict) -> None: - self._node_id = config.get("id") - if not self._node_id: + self.node_id = config.get("id") + if not self.node_id: raise ValueError("Node ID is required.") - self._node_data = self._node_data_cls(**config.get("data", {})) + self.node_data = self._node_data_cls(**config.get("data", {})) @abstractmethod def _run(self, variable_pool: Optional[VariablePool] = None, @@ -29,11 +33,13 @@ class BaseNode: raise NotImplementedError def run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> dict: + run_args: Optional[dict] = None, + callbacks: list[BaseWorkflowCallback] = None) -> dict: """ Run node entry :param variable_pool: variable pool :param run_args: run args + :param callbacks: callbacks :return: """ if variable_pool is None and run_args is None: @@ -52,3 +58,11 @@ class BaseNode: :return: """ return {} + + @property + def node_type(self) -> NodeType: + """ + Get node type + :return: + """ + return self._node_type diff --git a/api/core/workflow/nodes/start/entities.py b/api/core/workflow/nodes/start/entities.py index 
25b27cf192..64687db042 100644 --- a/api/core/workflow/nodes/start/entities.py +++ b/api/core/workflow/nodes/start/entities.py @@ -1,5 +1,3 @@ -from typing import Optional - from core.app.app_config.entities import VariableEntity from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeType @@ -22,6 +20,4 @@ class StartNodeData(BaseNodeData): """ type: str = NodeType.START.value - title: str - desc: Optional[str] = None variables: list[VariableEntity] = [] diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index e218cced3d..74d8541436 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -7,8 +7,8 @@ from core.workflow.nodes.start.entities import StartNodeData class StartNode(BaseNode): - _node_type = NodeType.START _node_data_cls = StartNodeData + node_type = NodeType.START def _run(self, variable_pool: Optional[VariablePool] = None, run_args: Optional[dict] = None) -> dict: diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 3ad36fe1d2..0ec93dd4b2 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,9 +1,12 @@ import json -from collections.abc import Generator +import time from typing import Optional, Union from core.workflow.callbacks.base_callback import BaseWorkflowCallback from core.workflow.entities.node_entities import NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.entities.workflow_entities import WorkflowRunState +from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode from core.workflow.nodes.end.end_node import EndNode @@ -19,7 +22,16 @@ from core.workflow.nodes.variable_assigner.variable_assigner_node import Variabl from extensions.ext_database import db from models.account import Account from models.model import App, EndUser -from models.workflow import CreatedByRole, Workflow, WorkflowRun, WorkflowRunStatus, WorkflowRunTriggeredFrom +from models.workflow import ( + CreatedByRole, + Workflow, + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, + WorkflowNodeExecutionTriggeredFrom, + WorkflowRun, + WorkflowRunStatus, + WorkflowRunTriggeredFrom, +) node_classes = { NodeType.START: StartNode, @@ -114,7 +126,7 @@ class WorkflowEngineManager: user: Union[Account, EndUser], user_inputs: dict, system_inputs: Optional[dict] = None, - callbacks: list[BaseWorkflowCallback] = None) -> Generator: + callbacks: list[BaseWorkflowCallback] = None) -> None: """ Run workflow :param app_model: App instance @@ -140,11 +152,66 @@ class WorkflowEngineManager: system_inputs=system_inputs ) + # init workflow run state + workflow_run_state = WorkflowRunState( + workflow_run=workflow_run, + start_at=time.perf_counter(), + variable_pool=VariablePool( + system_variables=system_inputs, + ) + ) + if callbacks: for callback in callbacks: callback.on_workflow_run_started(workflow_run) - pass + # fetch start node + start_node = self._get_entry_node(graph) + if not start_node: + self._workflow_run_failed( + workflow_run_state=workflow_run_state, + error='Start node not found in workflow graph', + callbacks=callbacks + ) + return + + try: + predecessor_node = None + current_node = start_node + while True: + # run workflow + 
self._run_workflow_node( + workflow_run_state=workflow_run_state, + node=current_node, + predecessor_node=predecessor_node, + callbacks=callbacks + ) + + if current_node.node_type == NodeType.END: + break + + # todo fetch next node until end node finished or no next node + current_node = None + + if not current_node: + break + + predecessor_node = current_node + # or max steps 30 reached + # or max execution time 10min reached + except Exception as e: + self._workflow_run_failed( + workflow_run_state=workflow_run_state, + error=str(e), + callbacks=callbacks + ) + return + + # workflow run success + self._workflow_run_success( + workflow_run_state=workflow_run_state, + callbacks=callbacks + ) def _init_workflow_run(self, workflow: Workflow, triggered_from: WorkflowRunTriggeredFrom, @@ -184,7 +251,7 @@ class WorkflowEngineManager: status=WorkflowRunStatus.RUNNING.value, created_by_role=(CreatedByRole.ACCOUNT.value if isinstance(user, Account) else CreatedByRole.END_USER.value), - created_by_id=user.id + created_by=user.id ) db.session.add(workflow_run) @@ -195,6 +262,33 @@ class WorkflowEngineManager: return workflow_run + def _workflow_run_failed(self, workflow_run_state: WorkflowRunState, + error: str, + callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun: + """ + Workflow run failed + :param workflow_run_state: workflow run state + :param error: error message + :param callbacks: workflow callbacks + :return: + """ + workflow_run = workflow_run_state.workflow_run + workflow_run.status = WorkflowRunStatus.FAILED.value + workflow_run.error = error + workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at + workflow_run.total_tokens = workflow_run_state.total_tokens + workflow_run.total_price = workflow_run_state.total_price + workflow_run.currency = workflow_run_state.currency + workflow_run.total_steps = len(workflow_run_state.workflow_node_executions) + + db.session.commit() + + if callbacks: + for callback in callbacks: + callback.on_workflow_run_finished(workflow_run) + + return workflow_run + def _get_entry_node(self, graph: dict) -> Optional[StartNode]: """ Get entry node @@ -210,3 +304,83 @@ class WorkflowEngineManager: return StartNode(config=node_config) return None + + def _run_workflow_node(self, workflow_run_state: WorkflowRunState, + node: BaseNode, + predecessor_node: Optional[BaseNode] = None, + callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution: + # init workflow node execution + start_at = time.perf_counter() + workflow_node_execution = self._init_node_execution_from_workflow_run( + workflow_run_state=workflow_run_state, + node=node, + predecessor_node=predecessor_node, + ) + + # add to workflow node executions + workflow_run_state.workflow_node_executions.append(workflow_node_execution) + + try: + # run node, result must have inputs, process_data, outputs, execution_metadata + node_run_result = node.run( + variable_pool=workflow_run_state.variable_pool, + callbacks=callbacks + ) + except Exception as e: + # node run failed + self._workflow_node_execution_failed( + workflow_node_execution=workflow_node_execution, + error=str(e), + callbacks=callbacks + ) + raise + + # node run success + self._workflow_node_execution_success( + workflow_node_execution=workflow_node_execution, + result=node_run_result, + callbacks=callbacks + ) + + return workflow_node_execution + + def _init_node_execution_from_workflow_run(self, workflow_run_state: WorkflowRunState, + node: BaseNode, + predecessor_node: Optional[BaseNode] = None, + 
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution: + """ + Init workflow node execution from workflow run + :param workflow_run_state: workflow run state + :param node: current node + :param predecessor_node: predecessor node if exists + :param callbacks: workflow callbacks + :return: + """ + workflow_run = workflow_run_state.workflow_run + + # init workflow node execution + workflow_node_execution = WorkflowNodeExecution( + tenant_id=workflow_run.tenant_id, + app_id=workflow_run.app_id, + workflow_id=workflow_run.workflow_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + workflow_run_id=workflow_run.id, + predecessor_node_id=predecessor_node.node_id if predecessor_node else None, + index=len(workflow_run_state.workflow_node_executions) + 1, + node_id=node.node_id, + node_type=node.node_type.value, + title=node.node_data.title, + type=node.node_type.value, + status=WorkflowNodeExecutionStatus.RUNNING.value, + created_by_role=workflow_run.created_by_role, + created_by=workflow_run.created_by + ) + + db.session.add(workflow_node_execution) + db.session.commit() + + if callbacks: + for callback in callbacks: + callback.on_workflow_node_execute_started(workflow_node_execution) + + return workflow_node_execution diff --git a/api/libs/helper.py b/api/libs/helper.py index a35f4ad471..3eb14c50f0 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -15,7 +15,7 @@ def run(script): class TimestampField(fields.Raw): - def format(self, value): + def format(self, value) -> int: return int(value.timestamp()) diff --git a/web/.husky/pre-commit b/web/.husky/pre-commit index 1f8ae9a8d3..4bc7fb77ab 100755 --- a/web/.husky/pre-commit +++ b/web/.husky/pre-commit @@ -31,14 +31,16 @@ if $api_modified; then pip install ruff fi - ruff check ./api - result=$? + ruff check ./api || status=$? - if [ $result -ne 0 ]; then + status=${status:-0} + + + if [ $status -ne 0 ]; then + echo "Ruff linter on api module error, exit code: $status" echo "Please run 'dev/reformat' to fix the fixable linting errors." 
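
The `|| status=$?` idiom in the hook above captures Ruff's exit code without letting a failing command abort the script early; the hook then exits non-zero itself so the hints are printed before the commit is rejected. The same control flow, sketched in Python for comparison (a hypothetical helper, not part of the patch):

    import subprocess
    import sys


    def run_ruff_check(path: str = "./api") -> None:
        # Run Ruff; on findings, print the hint and fail the way the hook does.
        result = subprocess.run(["ruff", "check", path])
        if result.returncode != 0:
            print(f"Ruff linter on api module error, exit code: {result.returncode}")
            print("Please run 'dev/reformat' to fix the fixable linting errors.")
            sys.exit(1)
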
+ exit 1 fi - - exit $result fi if $web_modified; then From dd50deaa438dc264ebfcbaf30e9fab30824ea681 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 6 Mar 2024 13:45:01 +0800 Subject: [PATCH 113/200] fix audio voice arg --- api/services/audio_service.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 7a658487f8..d013a51c3e 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -64,7 +64,8 @@ class AudioService: return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)} @classmethod - def transcript_tts(cls, app_model: App, text: str, streaming: bool, end_user: Optional[str] = None): + def transcript_tts(cls, app_model: App, text: str, streaming: bool, + voice: Optional[str] = None, end_user: Optional[str] = None): if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: workflow = app_model.workflow if workflow is None: @@ -74,14 +75,14 @@ class AudioService: if 'text_to_speech' not in features_dict or not features_dict['text_to_speech'].get('enabled'): raise ValueError("TTS is not enabled") - voice = features_dict['text_to_speech'].get('voice') + voice = features_dict['text_to_speech'].get('voice') if voice is None else voice else: text_to_speech_dict = app_model.app_model_config.text_to_speech_dict if not text_to_speech_dict.get('enabled'): raise ValueError("TTS is not enabled") - voice = text_to_speech_dict.get('voice'), + voice = text_to_speech_dict.get('voice') if voice is None else voice model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( From 7d28fe8ea5d0b295a4d2e0073c8593fcc86f1870 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 6 Mar 2024 17:43:42 +0800 Subject: [PATCH 114/200] completed workflow engine main logic --- api/core/app/apps/advanced_chat/app_runner.py | 3 +- .../advanced_chat/generate_task_pipeline.py | 2 - .../workflow_event_trigger_callback.py | 11 +- ..._callback.py => base_workflow_callback.py} | 8 + api/core/workflow/entities/node_entities.py | 21 ++ .../workflow/entities/workflow_entities.py | 9 +- api/core/workflow/nodes/base_node.py | 48 ++- api/core/workflow/workflow_engine_manager.py | 334 +++++++++++++++--- api/fields/workflow_run_fields.py | 6 - api/models/workflow.py | 4 - 10 files changed, 366 insertions(+), 80 deletions(-) rename api/core/workflow/callbacks/{base_callback.py => base_workflow_callback.py} (85%) diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 898091f52c..c5ffa80165 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -83,7 +83,6 @@ class AdvancedChatAppRunner(AppRunner): # RUN WORKFLOW workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( - app_model=app_record, workflow=workflow, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN, @@ -94,7 +93,7 @@ class AdvancedChatAppRunner(AppRunner): SystemVariable.FILES: files, SystemVariable.CONVERSATION: conversation.id, }, - callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)] + callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)], ) def handle_input_moderation(self, queue_manager: AppQueueManager, diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py 
b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 77e779a0ad..cfeb46f05a 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -253,8 +253,6 @@ class AdvancedChatAppGenerateTaskPipeline: 'error': workflow_run.error, 'elapsed_time': workflow_run.elapsed_time, 'total_tokens': workflow_run.total_tokens, - 'total_price': workflow_run.total_price, - 'currency': workflow_run.currency, 'total_steps': workflow_run.total_steps, 'created_at': int(workflow_run.created_at.timestamp()), 'finished_at': int(workflow_run.finished_at.timestamp()) diff --git a/api/core/callback_handler/workflow_event_trigger_callback.py b/api/core/callback_handler/workflow_event_trigger_callback.py index e1d2413534..80dabc7548 100644 --- a/api/core/callback_handler/workflow_event_trigger_callback.py +++ b/api/core/callback_handler/workflow_event_trigger_callback.py @@ -1,5 +1,5 @@ from core.app.app_queue_manager import AppQueueManager, PublishFrom -from core.workflow.callbacks.base_callback import BaseWorkflowCallback +from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from models.workflow import WorkflowNodeExecution, WorkflowRun @@ -43,3 +43,12 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): workflow_node_execution_id=workflow_node_execution.id, pub_from=PublishFrom.TASK_PIPELINE ) + + def on_text_chunk(self, text: str) -> None: + """ + Publish text chunk + """ + self._queue_manager.publish_text_chunk( + text=text, + pub_from=PublishFrom.TASK_PIPELINE + ) diff --git a/api/core/workflow/callbacks/base_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py similarity index 85% rename from api/core/workflow/callbacks/base_callback.py rename to api/core/workflow/callbacks/base_workflow_callback.py index 76fe4d96d5..3425b2b03c 100644 --- a/api/core/workflow/callbacks/base_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -31,3 +31,11 @@ class BaseWorkflowCallback: Workflow node execute finished """ raise NotImplementedError + + @abstractmethod + def on_text_chunk(self, text: str) -> None: + """ + Publish text chunk + """ + raise NotImplementedError + diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 18f0f7746c..af539692ef 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -1,4 +1,9 @@ from enum import Enum +from typing import Optional + +from pydantic import BaseModel + +from models.workflow import WorkflowNodeExecutionStatus class NodeType(Enum): @@ -39,3 +44,19 @@ class SystemVariable(Enum): QUERY = 'query' FILES = 'files' CONVERSATION = 'conversation' + + +class NodeRunResult(BaseModel): + """ + Node Run Result. 
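
`NodeRunResult`, defined just below, becomes the uniform return type for node execution: each node reports a status, its inputs and outputs for tracing, optional metadata such as token usage, and, for branching nodes, which outgoing edge to follow. A successful run might be reported like this (illustrative values only, assuming the modules from this patch are importable):

    from core.workflow.entities.node_entities import NodeRunResult
    from models.workflow import WorkflowNodeExecutionStatus

    # A node returning its result after a successful execution.
    result = NodeRunResult(
        status=WorkflowNodeExecutionStatus.SUCCEEDED,
        inputs={"query": "hello"},
        outputs={"text": "Hello! How can I help?"},
        metadata={"total_tokens": 42},
    )
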
+ """ + status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING + + inputs: Optional[dict] = None # node inputs + process_data: Optional[dict] = None # process data + outputs: Optional[dict] = None # node outputs + metadata: Optional[dict] = None # node metadata + + edge_source_handle: Optional[str] = None # source handle id of node with multiple branches + + error: Optional[str] = None # error message if status is failed diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 21126caf30..0d78e4c4f1 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -1,5 +1,3 @@ -from decimal import Decimal - from core.workflow.entities.variable_pool import VariablePool from models.workflow import WorkflowNodeExecution, WorkflowRun @@ -10,7 +8,10 @@ class WorkflowRunState: variable_pool: VariablePool total_tokens: int = 0 - total_price: Decimal = Decimal(0) - currency: str = "USD" workflow_node_executions: list[WorkflowNodeExecution] = [] + + def __init__(self, workflow_run: WorkflowRun, start_at: float, variable_pool: VariablePool) -> None: + self.workflow_run = workflow_run + self.start_at = start_at + self.variable_pool = variable_pool diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 314dfb8f22..efffdfae1a 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -1,10 +1,11 @@ from abc import abstractmethod from typing import Optional -from core.workflow.callbacks.base_callback import BaseWorkflowCallback +from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType +from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool +from models.workflow import WorkflowNodeExecutionStatus class BaseNode: @@ -13,17 +14,23 @@ class BaseNode: node_id: str node_data: BaseNodeData + node_run_result: Optional[NodeRunResult] = None - def __init__(self, config: dict) -> None: + stream_output_supported: bool = False + callbacks: list[BaseWorkflowCallback] + + def __init__(self, config: dict, + callbacks: list[BaseWorkflowCallback] = None) -> None: self.node_id = config.get("id") if not self.node_id: raise ValueError("Node ID is required.") self.node_data = self._node_data_cls(**config.get("data", {})) + self.callbacks = callbacks or [] @abstractmethod def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> dict: + run_args: Optional[dict] = None) -> NodeRunResult: """ Run node :param variable_pool: variable pool @@ -33,22 +40,41 @@ class BaseNode: raise NotImplementedError def run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None, - callbacks: list[BaseWorkflowCallback] = None) -> dict: + run_args: Optional[dict] = None) -> NodeRunResult: """ Run node entry :param variable_pool: variable pool :param run_args: run args - :param callbacks: callbacks :return: """ if variable_pool is None and run_args is None: raise ValueError("At least one of `variable_pool` or `run_args` must be provided.") - return self._run( - variable_pool=variable_pool, - run_args=run_args - ) + try: + result = self._run( + variable_pool=variable_pool, + run_args=run_args + ) + except Exception as e: + # process unhandled 
exception + result = NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e) + ) + + self.node_run_result = result + return result + + def publish_text_chunk(self, text: str) -> None: + """ + Publish text chunk + :param text: chunk text + :return: + """ + if self.stream_output_supported: + if self.callbacks: + for callback in self.callbacks: + callback.on_text_chunk(text) @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 0ec93dd4b2..908b684930 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,10 +1,11 @@ import json import time +from datetime import datetime from typing import Optional, Union -from core.workflow.callbacks.base_callback import BaseWorkflowCallback -from core.workflow.entities.node_entities import NodeType -from core.workflow.entities.variable_pool import VariablePool +from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.workflow_entities import WorkflowRunState from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.code.code_node import CodeNode @@ -31,6 +32,7 @@ from models.workflow import ( WorkflowRun, WorkflowRunStatus, WorkflowRunTriggeredFrom, + WorkflowType, ) node_classes = { @@ -120,8 +122,7 @@ class WorkflowEngineManager: return default_config - def run_workflow(self, app_model: App, - workflow: Workflow, + def run_workflow(self, workflow: Workflow, triggered_from: WorkflowRunTriggeredFrom, user: Union[Account, EndUser], user_inputs: dict, @@ -129,7 +130,6 @@ class WorkflowEngineManager: callbacks: list[BaseWorkflowCallback] = None) -> None: """ Run workflow - :param app_model: App instance :param workflow: Workflow instance :param triggered_from: triggered from :param user: account or end user @@ -143,13 +143,23 @@ class WorkflowEngineManager: if not graph: raise ValueError('workflow graph not found') + if 'nodes' not in graph or 'edges' not in graph: + raise ValueError('nodes or edges not found in workflow graph') + + if isinstance(graph.get('nodes'), list): + raise ValueError('nodes in workflow graph must be a list') + + if isinstance(graph.get('edges'), list): + raise ValueError('edges in workflow graph must be a list') + # init workflow run workflow_run = self._init_workflow_run( workflow=workflow, triggered_from=triggered_from, user=user, user_inputs=user_inputs, - system_inputs=system_inputs + system_inputs=system_inputs, + callbacks=callbacks ) # init workflow run state @@ -161,44 +171,54 @@ class WorkflowEngineManager: ) ) - if callbacks: - for callback in callbacks: - callback.on_workflow_run_started(workflow_run) - - # fetch start node - start_node = self._get_entry_node(graph) - if not start_node: - self._workflow_run_failed( - workflow_run_state=workflow_run_state, - error='Start node not found in workflow graph', - callbacks=callbacks - ) - return + # fetch predecessor node ids before end node (include: llm, direct answer) + streamable_node_ids = self._fetch_streamable_node_ids(workflow, graph) try: predecessor_node = None - current_node = start_node while True: - # run workflow - self._run_workflow_node( - workflow_run_state=workflow_run_state, - node=current_node, + # get next node, multiple target 
nodes in the future + next_node = self._get_next_node( + graph=graph, predecessor_node=predecessor_node, callbacks=callbacks ) - if current_node.node_type == NodeType.END: + if not next_node: break - # todo fetch next node until end node finished or no next node - current_node = None + # check if node is streamable + if next_node.node_id in streamable_node_ids: + next_node.stream_output_supported = True - if not current_node: - break + # max steps 30 reached + if len(workflow_run_state.workflow_node_executions) > 30: + raise ValueError('Max steps 30 reached.') - predecessor_node = current_node - # or max steps 30 reached # or max execution time 10min reached + if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=600): + raise ValueError('Max execution time 10min reached.') + + # run workflow, run multiple target nodes in the future + self._run_workflow_node( + workflow_run_state=workflow_run_state, + node=next_node, + predecessor_node=predecessor_node, + callbacks=callbacks + ) + + if next_node.node_type == NodeType.END: + break + + predecessor_node = next_node + + if not predecessor_node and not next_node: + self._workflow_run_failed( + workflow_run_state=workflow_run_state, + error='Start node not found in workflow graph.', + callbacks=callbacks + ) + return except Exception as e: self._workflow_run_failed( workflow_run_state=workflow_run_state, @@ -213,11 +233,40 @@ class WorkflowEngineManager: callbacks=callbacks ) + def _fetch_streamable_node_ids(self, workflow: Workflow, graph: dict) -> list[str]: + """ + Fetch streamable node ids + When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output + When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output + + :param workflow: Workflow instance + :param graph: workflow graph + :return: + """ + workflow_type = WorkflowType.value_of(workflow.type) + + streamable_node_ids = [] + end_node_ids = [] + for node_config in graph.get('nodes'): + if node_config.get('type') == NodeType.END.value: + if workflow_type == WorkflowType.WORKFLOW: + if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text': + end_node_ids.append(node_config.get('id')) + else: + end_node_ids.append(node_config.get('id')) + + for edge_config in graph.get('edges'): + if edge_config.get('target') in end_node_ids: + streamable_node_ids.append(edge_config.get('source')) + + return streamable_node_ids + def _init_workflow_run(self, workflow: Workflow, triggered_from: WorkflowRunTriggeredFrom, user: Union[Account, EndUser], user_inputs: dict, - system_inputs: Optional[dict] = None) -> WorkflowRun: + system_inputs: Optional[dict] = None, + callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun: """ Init workflow run :param workflow: Workflow instance @@ -225,6 +274,7 @@ class WorkflowEngineManager: :param user: account or end user :param user_inputs: user variables inputs :param system_inputs: system inputs, like: query, files + :param callbacks: workflow callbacks :return: """ try: @@ -260,6 +310,39 @@ class WorkflowEngineManager: db.session.rollback() raise + if callbacks: + for callback in callbacks: + callback.on_workflow_run_started(workflow_run) + + return workflow_run + + def _workflow_run_success(self, workflow_run_state: WorkflowRunState, + callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun: + """ + Workflow run success + :param workflow_run_state: workflow run state + :param callbacks: workflow 
+
+        :param workflow: Workflow instance
+        :param graph: workflow graph
+        :return:
+        """
+        workflow_type = WorkflowType.value_of(workflow.type)
+
+        streamable_node_ids = []
+        end_node_ids = []
+        for node_config in graph.get('nodes'):
+            if node_config.get('type') == NodeType.END.value:
+                if workflow_type == WorkflowType.WORKFLOW:
+                    if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text':
+                        end_node_ids.append(node_config.get('id'))
+                else:
+                    end_node_ids.append(node_config.get('id'))
+
+        for edge_config in graph.get('edges'):
+            if edge_config.get('target') in end_node_ids:
+                streamable_node_ids.append(edge_config.get('source'))
+
+        return streamable_node_ids
+
     def _init_workflow_run(self, workflow: Workflow,
                            triggered_from: WorkflowRunTriggeredFrom,
                            user: Union[Account, EndUser],
                            user_inputs: dict,
-                           system_inputs: Optional[dict] = None) -> WorkflowRun:
+                           system_inputs: Optional[dict] = None,
+                           callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun:
         """
         Init workflow run
         :param workflow: Workflow instance
@@ -225,6 +274,7 @@ class WorkflowEngineManager:
         :param user: account or end user
         :param user_inputs: user variables inputs
         :param system_inputs: system inputs, like: query, files
+        :param callbacks: workflow callbacks
         :return:
         """
         try:
@@ -260,6 +310,39 @@ class WorkflowEngineManager:
             db.session.rollback()
             raise
 
+        if callbacks:
+            for callback in callbacks:
+                callback.on_workflow_run_started(workflow_run)
+
+        return workflow_run
+
+    def _workflow_run_success(self, workflow_run_state: WorkflowRunState,
+                              callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun:
+        """
+        Workflow run success
+        :param workflow_run_state: workflow run state
+        :param callbacks: workflow callbacks
+        :return:
+        """
+        workflow_run = workflow_run_state.workflow_run
+        workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
+
+        # fetch the last workflow node execution, if any
+        last_workflow_node_execution = workflow_run_state.workflow_node_executions[-1] \
+            if workflow_run_state.workflow_node_executions else None
+        if last_workflow_node_execution:
+            workflow_run.outputs = json.dumps(last_workflow_node_execution.node_run_result.outputs)
+
+        workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at
+        workflow_run.total_tokens = workflow_run_state.total_tokens
+        workflow_run.total_steps = len(workflow_run_state.workflow_node_executions)
+        workflow_run.finished_at = datetime.utcnow()
+
+        db.session.commit()
+
+        if callbacks:
+            for callback in callbacks:
+                callback.on_workflow_run_finished(workflow_run)
+
         return workflow_run
 
     def _workflow_run_failed(self, workflow_run_state: WorkflowRunState,
@@ -277,9 +360,8 @@ class WorkflowEngineManager:
         workflow_run.error = error
         workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at
         workflow_run.total_tokens = workflow_run_state.total_tokens
-        workflow_run.total_price = workflow_run_state.total_price
-        workflow_run.currency = workflow_run_state.currency
         workflow_run.total_steps = len(workflow_run_state.workflow_node_executions)
+        workflow_run.finished_at = datetime.utcnow()
 
         db.session.commit()
 
@@ -289,21 +371,77 @@ class WorkflowEngineManager:
 
         return workflow_run
 
-    def _get_entry_node(self, graph: dict) -> Optional[StartNode]:
+    def _get_next_node(self, graph: dict,
+                       predecessor_node: Optional[BaseNode] = None,
+                       callbacks: list[BaseWorkflowCallback] = None) -> Optional[BaseNode]:
         """
-        Get entry node
+        Get next node
+        (multiple target nodes will be supported in the future)
         :param graph: workflow graph
+        :param predecessor_node: predecessor node
+        :param callbacks: workflow callbacks
         :return:
         """
         nodes = graph.get('nodes')
         if not nodes:
             return None
 
-        for node_config in nodes.items():
-            if node_config.get('type') == NodeType.START.value:
-                return StartNode(config=node_config)
+        if not predecessor_node:
+            for node_config in nodes:
+                if node_config.get('type') == NodeType.START.value:
+                    return StartNode(config=node_config)
+        else:
+            edges = graph.get('edges')
+            source_node_id = predecessor_node.node_id
 
-        return None
+            # fetch all outgoing edges from source node
+            outgoing_edges = [edge for edge in edges if edge.get('source') == source_node_id]
+            if not outgoing_edges:
+                return None
+
+            # fetch target node id from outgoing edges
+            outgoing_edge = None
+            source_handle = predecessor_node.node_run_result.edge_source_handle
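+            # illustration: a branching node (e.g. if-else) can set
+            # edge_source_handle on its run result ('true'/'false' are assumed
+            # example values); the edge with the matching source_handle is taken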
+            if source_handle:
+                for edge in outgoing_edges:
+                    if edge.get('source_handle') and edge.get('source_handle') == source_handle:
+                        outgoing_edge = edge
+                        break
+            else:
+                outgoing_edge = outgoing_edges[0]
+
+            if not outgoing_edge:
+                return None
+
+            target_node_id = outgoing_edge.get('target')
+
+            # fetch target node from target node id
+            target_node_config = None
+            for node in nodes:
+                if node.get('id') == target_node_id:
+                    target_node_config = node
+                    break
+
+            if not target_node_config:
+                return None
+
+            # get next node
+            target_node = node_classes.get(NodeType.value_of(target_node_config.get('type')))
+
+            return target_node(
+                config=target_node_config,
+                callbacks=callbacks
+            )
+
+    def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
+        """
+        Check timeout
+        :param start_at: start time
+        :param max_execution_time: max execution time
+        :return:
+        """
+        # TODO check queue is stopped
+        return time.perf_counter() - start_at > max_execution_time
 
     def _run_workflow_node(self, workflow_run_state: WorkflowRunState,
                            node: BaseNode,
@@ -320,28 +458,41 @@ class WorkflowEngineManager:
         # add to workflow node executions
         workflow_run_state.workflow_node_executions.append(workflow_node_execution)
 
-        try:
-            # run node, result must have inputs, process_data, outputs, execution_metadata
-            node_run_result = node.run(
-                variable_pool=workflow_run_state.variable_pool,
-                callbacks=callbacks
-            )
-        except Exception as e:
+        # run node, result must have inputs, process_data, outputs, execution_metadata
+        node_run_result = node.run(
+            variable_pool=workflow_run_state.variable_pool
+        )
+
+        if node_run_result.status == WorkflowNodeExecutionStatus.FAILED:
             # node run failed
             self._workflow_node_execution_failed(
                 workflow_node_execution=workflow_node_execution,
-                error=str(e),
+                start_at=start_at,
+                error=node_run_result.error,
                 callbacks=callbacks
             )
-            raise
+            raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}")
 
         # node run success
         self._workflow_node_execution_success(
             workflow_node_execution=workflow_node_execution,
+            start_at=start_at,
            result=node_run_result,
             callbacks=callbacks
         )
 
+        for variable_key, variable_value in node_run_result.outputs.items():
+            # append variables to variable pool recursively
+            self._append_variables_recursively(
+                variable_pool=workflow_run_state.variable_pool,
+                node_id=node.node_id,
+                variable_key_list=[variable_key],
+                variable_value=variable_value
+            )
+
+        if node_run_result.metadata and node_run_result.metadata.get('total_tokens'):
+            workflow_run_state.total_tokens += int(node_run_result.metadata.get('total_tokens'))
+
         return workflow_node_execution
 
     def _init_node_execution_from_workflow_run(self, workflow_run_state: WorkflowRunState,
@@ -384,3 +535,86 @@ class WorkflowEngineManager:
                 callback.on_workflow_node_execute_started(workflow_node_execution)
 
         return workflow_node_execution
+
+    def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution,
+                                         start_at: float,
+                                         result: NodeRunResult,
+                                         callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution:
+        """
+        Workflow node execution success
+        :param workflow_node_execution: workflow node execution
+        :param start_at: start time
+        :param result: node run result
+        :param callbacks: workflow callbacks
+        :return:
+        """
+        workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
+        workflow_node_execution.elapsed_time = time.perf_counter() - start_at
+        workflow_node_execution.inputs = json.dumps(result.inputs)
+        workflow_node_execution.process_data = json.dumps(result.process_data)
+        workflow_node_execution.outputs = json.dumps(result.outputs)
+        workflow_node_execution.execution_metadata = json.dumps(result.metadata)
+        workflow_node_execution.finished_at = datetime.utcnow()
+
+        db.session.commit()
+
+        if callbacks:
+            for callback in callbacks:
+                callback.on_workflow_node_execute_finished(workflow_node_execution)
+
+        return workflow_node_execution
+
+    def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution,
+                                        start_at: float,
+                                        error: str,
+                                        callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution:
+        """
+        Workflow node execution failed
+        :param workflow_node_execution: workflow node execution
+        :param start_at: start time
+        :param error: error message
+        :param callbacks: workflow callbacks
+        :return:
+        """
+        workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
+        workflow_node_execution.error = error
+        workflow_node_execution.elapsed_time = time.perf_counter() - start_at
+        workflow_node_execution.finished_at = datetime.utcnow()
+
+        db.session.commit()
+
+        if callbacks:
+            for callback in callbacks:
+                callback.on_workflow_node_execute_finished(workflow_node_execution)
+
+        return workflow_node_execution
+
+    def _append_variables_recursively(self, variable_pool: VariablePool,
+                                      node_id: str,
+                                      variable_key_list: list[str],
+                                      variable_value: VariableValue):
+        """
+        Append variables recursively
+        :param variable_pool: variable pool
+        :param node_id: node id
+        :param variable_key_list: variable key list
+        :param variable_value: variable value
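+
+        Illustration ('n1', 'a', 'b' are example names): appending {'b': 1} under
+        node 'n1' with key list ['a'] registers ['a'] -> {'b': 1} and then,
+        recursively, ['a', 'b'] -> 1.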
+        :return:
+        """
+        variable_pool.append_variable(
+            node_id=node_id,
+            variable_key_list=variable_key_list,
+            value=variable_value
+        )
+
+        # if variable_value is a dict, then recursively append variables
+        if isinstance(variable_value, dict):
+            for key, value in variable_value.items():
+                # construct new key list
+                new_key_list = variable_key_list + [key]
+                self._append_variables_recursively(
+                    variable_pool=variable_pool,
+                    node_id=node_id,
+                    variable_key_list=new_key_list,
+                    variable_value=value
+                )
diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py
index 85c9c2d2b2..572f472f1f 100644
--- a/api/fields/workflow_run_fields.py
+++ b/api/fields/workflow_run_fields.py
@@ -11,8 +11,6 @@ workflow_run_for_log_fields = {
     "error": fields.String,
     "elapsed_time": fields.Float,
     "total_tokens": fields.Integer,
-    "total_price": fields.Float,
-    "currency": fields.String,
     "total_steps": fields.Integer,
     "created_at": TimestampField,
     "finished_at": TimestampField
@@ -29,8 +27,6 @@ workflow_run_for_list_fields = {
     "error": fields.String,
     "elapsed_time": fields.Float,
     "total_tokens": fields.Integer,
-    "total_price": fields.Float,
-    "currency": fields.String,
     "total_steps": fields.Integer,
     "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
     "created_at": TimestampField,
@@ -56,8 +52,6 @@ workflow_run_detail_fields = {
     "error": fields.String,
     "elapsed_time": fields.Float,
     "total_tokens": fields.Integer,
-    "total_price": fields.Float,
-    "currency": fields.String,
     "total_steps": fields.Integer,
     "created_by_role": fields.String,
     "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
diff --git a/api/models/workflow.py b/api/models/workflow.py
index 32ff26196c..032134a0d1 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -216,8 +216,6 @@ class WorkflowRun(db.Model):
     - error (string) `optional` Error reason
     - elapsed_time (float) `optional` Time consumption (s)
     - total_tokens (int) `optional` Total tokens used
-    - total_price (decimal) `optional` Total cost
-    - currency (string) `optional` Currency, such as USD / RMB
     - total_steps (int) Total steps (redundant), default 0
 
     - created_by_role (string) Creator role
@@ -251,8 +249,6 @@ class WorkflowRun(db.Model):
     error = db.Column(db.Text)
     elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0'))
     total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
-    total_price = db.Column(db.Numeric(10, 7))
-    currency = db.Column(db.String(255))
     total_steps = db.Column(db.Integer, server_default=db.text('0'))
     created_by_role = db.Column(db.String(255), nullable=False)
     created_by = db.Column(UUID, nullable=False)

From a1bc6b50c5488bee749d1111dc979ec69255a447 Mon Sep 17 00:00:00 2001
From: takatost
Date: Wed, 6 Mar 2024 22:10:49 +0800
Subject: [PATCH 115/200] refactor workflow generate pipeline

---
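After this refactor, callers publish typed queue events directly instead of
calling per-event helper methods, e.g. (a sketch using names from this patch):

    queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER)

replaces queue_manager.publish_text_chunk(token, PublishFrom.APPLICATION_MANAGER).
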
api/controllers/console/app/completion.py | 2 +- api/controllers/console/explore/completion.py | 2 +- api/controllers/service_api/app/completion.py | 2 +- api/controllers/web/completion.py | 2 +- api/core/agent/base_agent_runner.py | 2 +- api/core/agent/cot_agent_runner.py | 31 +- api/core/agent/fc_agent_runner.py | 30 +- api/core/app/app_queue_manager.py | 335 -------------- .../app/apps/advanced_chat/app_generator.py | 5 +- api/core/app/apps/advanced_chat/app_runner.py | 19 +- .../advanced_chat/generate_task_pipeline.py | 12 +- api/core/app/apps/agent_chat/app_generator.py | 5 +- api/core/app/apps/agent_chat/app_runner.py | 10 +- api/core/app/apps/base_app_queue_manager.py | 181 ++++++++ api/core/app/apps/base_app_runner.py | 58 ++- api/core/app/apps/chat/app_generator.py | 5 +- api/core/app/apps/chat/app_runner.py | 10 +- api/core/app/apps/completion/app_generator.py | 7 +- api/core/app/apps/completion/app_runner.py | 2 +- .../easy_ui_based_generate_task_pipeline.py | 25 +- .../app/apps/message_based_app_generator.py | 2 +- .../apps/message_based_app_queue_manager.py | 29 ++ api/core/app/apps/workflow/app_generator.py | 164 +++++++ .../app/apps/workflow/app_queue_manager.py | 23 + api/core/app/apps/workflow/app_runner.py | 156 +++++++ .../apps/workflow/generate_task_pipeline.py | 408 ++++++++++++++++++ api/core/app/entities/app_invoke_entities.py | 4 +- .../index_tool_callback_handler.py | 8 +- .../workflow_event_trigger_callback.py | 41 +- api/core/moderation/output_moderation.py | 19 +- api/services/workflow_service.py | 21 +- 31 files changed, 1175 insertions(+), 445 deletions(-) delete mode 100644 api/core/app/app_queue_manager.py create mode 100644 api/core/app/apps/base_app_queue_manager.py create mode 100644 api/core/app/apps/message_based_app_queue_manager.py create mode 100644 api/core/app/apps/workflow/app_generator.py create mode 100644 api/core/app/apps/workflow/app_queue_manager.py create mode 100644 api/core/app/apps/workflow/app_runner.py create mode 100644 api/core/app/apps/workflow/generate_task_pipeline.py diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index fd6cfadfef..a7fd0164d8 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -21,7 +21,7 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.app.app_queue_manager import AppQueueManager +from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index dd531974fa..b8a5be0df0 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -21,7 +21,7 @@ from controllers.console.app.error import ( ) from controllers.console.explore.error import NotChatAppError, NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource -from core.app.app_queue_manager import AppQueueManager +from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error 
import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 5c488093fa..410fb5bffd 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -19,7 +19,7 @@ from controllers.service_api.app.error import ( ProviderQuotaExceededError, ) from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token -from core.app.app_queue_manager import AppQueueManager +from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 785e2b8d6b..ed1378e7e3 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -20,7 +20,7 @@ from controllers.web.error import ( ProviderQuotaExceededError, ) from controllers.web.wraps import WebApiResource -from core.app.app_queue_manager import AppQueueManager +from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 236a5d9cf7..0901b7e965 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -6,8 +6,8 @@ from mimetypes import guess_extension from typing import Optional, Union, cast from core.agent.entities import AgentEntity, AgentToolEntity -from core.app.app_queue_manager import AppQueueManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig +from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import ( AgentChatAppGenerateEntity, diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 8b444ef3be..cbb19aca53 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -5,7 +5,8 @@ from typing import Literal, Union from core.agent.base_agent_runner import BaseAgentRunner from core.agent.entities import AgentPromptEntity, AgentScratchpadUnit -from core.app.app_queue_manager import PublishFrom +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -121,7 +122,9 @@ class CotAgentRunner(BaseAgentRunner): ) if iteration_step > 1: - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) # update prompt messages prompt_messages = self._organize_cot_prompt_messages( @@ -163,7 +166,9 @@ class CotAgentRunner(BaseAgentRunner): # publish agent 
thought if it's first iteration if iteration_step == 1: - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) for chunk in react_chunks: if isinstance(chunk, dict): @@ -225,7 +230,9 @@ class CotAgentRunner(BaseAgentRunner): llm_usage=usage_dict['usage']) if scratchpad.action and scratchpad.action.action_name.lower() != "final answer": - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) if not scratchpad.action: # failed to extract action, return final answer directly @@ -255,7 +262,9 @@ class CotAgentRunner(BaseAgentRunner): observation=answer, answer=answer, messages_ids=[]) - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) else: # invoke tool error_response = None @@ -282,7 +291,9 @@ class CotAgentRunner(BaseAgentRunner): self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as) - self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueMessageFileEvent( + message_file_id=message_file.id + ), PublishFrom.APPLICATION_MANAGER) message_file_ids = [message_file.id for message_file, _ in message_files] except ToolProviderCredentialValidationError as e: @@ -318,7 +329,9 @@ class CotAgentRunner(BaseAgentRunner): answer=scratchpad.agent_response, messages_ids=message_file_ids, ) - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) # update prompt tool message for prompt_tool in prompt_messages_tools: @@ -352,7 +365,7 @@ class CotAgentRunner(BaseAgentRunner): self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event - self.queue_manager.publish_message_end(LLMResult( + self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult( model=model_instance.model, prompt_messages=prompt_messages, message=AssistantPromptMessage( @@ -360,7 +373,7 @@ class CotAgentRunner(BaseAgentRunner): ), usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(), system_fingerprint='' - ), PublishFrom.APPLICATION_MANAGER) + )), PublishFrom.APPLICATION_MANAGER) def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \ -> Generator[Union[str, dict], None, None]: diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 30e5cdd694..7c3849a12c 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -4,7 +4,8 @@ from collections.abc import Generator from typing import Any, Union from core.agent.base_agent_runner import BaseAgentRunner -from core.app.app_queue_manager import PublishFrom +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities 
import ( AssistantPromptMessage, @@ -135,7 +136,9 @@ class FunctionCallAgentRunner(BaseAgentRunner): is_first_chunk = True for chunk in chunks: if is_first_chunk: - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) is_first_chunk = False # check if there is any tool call if self.check_tool_calls(chunk): @@ -195,7 +198,9 @@ class FunctionCallAgentRunner(BaseAgentRunner): if not result.message.content: result.message.content = '' - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) yield LLMResultChunk( model=model_instance.model, @@ -233,8 +238,9 @@ class FunctionCallAgentRunner(BaseAgentRunner): messages_ids=[], llm_usage=current_llm_usage ) - - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) final_answer += response + '\n' @@ -275,7 +281,9 @@ class FunctionCallAgentRunner(BaseAgentRunner): self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as) # publish message file - self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueMessageFileEvent( + message_file_id=message_file.id + ), PublishFrom.APPLICATION_MANAGER) # add message file ids message_file_ids.append(message_file.id) @@ -331,7 +339,9 @@ class FunctionCallAgentRunner(BaseAgentRunner): answer=None, messages_ids=message_file_ids ) - self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) # update prompt tool for prompt_tool in prompt_messages_tools: @@ -341,15 +351,15 @@ class FunctionCallAgentRunner(BaseAgentRunner): self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event - self.queue_manager.publish_message_end(LLMResult( + self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult( model=model_instance.model, prompt_messages=prompt_messages, message=AssistantPromptMessage( - content=final_answer, + content=final_answer ), usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(), system_fingerprint='' - ), PublishFrom.APPLICATION_MANAGER) + )), PublishFrom.APPLICATION_MANAGER) def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool: """ diff --git a/api/core/app/app_queue_manager.py b/api/core/app/app_queue_manager.py deleted file mode 100644 index 5655c8d979..0000000000 --- a/api/core/app/app_queue_manager.py +++ /dev/null @@ -1,335 +0,0 @@ -import queue -import time -from collections.abc import Generator -from enum import Enum -from typing import Any - -from sqlalchemy.orm import DeclarativeMeta - -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.entities.queue_entities import ( - AppQueueEvent, - QueueAgentMessageEvent, - QueueAgentThoughtEvent, - QueueAnnotationReplyEvent, - QueueErrorEvent, - QueueLLMChunkEvent, - QueueMessage, - QueueMessageEndEvent, - QueueMessageFileEvent, - QueueMessageReplaceEvent, - QueueNodeFinishedEvent, - QueueNodeStartedEvent, - QueuePingEvent, - 
QueueRetrieverResourcesEvent, - QueueStopEvent, - QueueTextChunkEvent, - QueueWorkflowFinishedEvent, - QueueWorkflowStartedEvent, -) -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from extensions.ext_redis import redis_client -from models.model import MessageAgentThought, MessageFile - - -class PublishFrom(Enum): - APPLICATION_MANAGER = 1 - TASK_PIPELINE = 2 - - -class AppQueueManager: - def __init__(self, task_id: str, - user_id: str, - invoke_from: InvokeFrom, - conversation_id: str, - app_mode: str, - message_id: str) -> None: - if not user_id: - raise ValueError("user is required") - - self._task_id = task_id - self._user_id = user_id - self._invoke_from = invoke_from - self._conversation_id = str(conversation_id) - self._app_mode = app_mode - self._message_id = str(message_id) - - user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' - redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}") - - q = queue.Queue() - - self._q = q - - def listen(self) -> Generator: - """ - Listen to queue - :return: - """ - # wait for 10 minutes to stop listen - listen_timeout = 600 - start_time = time.time() - last_ping_time = 0 - - while True: - try: - message = self._q.get(timeout=1) - if message is None: - break - - yield message - except queue.Empty: - continue - finally: - elapsed_time = time.time() - start_time - if elapsed_time >= listen_timeout or self._is_stopped(): - # publish two messages to make sure the client can receive the stop signal - # and stop listening after the stop signal processed - self.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), - PublishFrom.TASK_PIPELINE - ) - self.stop_listen() - - if elapsed_time // 10 > last_ping_time: - self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) - last_ping_time = elapsed_time // 10 - - def stop_listen(self) -> None: - """ - Stop listen to queue - :return: - """ - self._q.put(None) - - def publish_llm_chunk(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None: - """ - Publish llm chunk to channel - - :param chunk: llm chunk - :param pub_from: publish from - :return: - """ - self.publish(QueueLLMChunkEvent( - chunk=chunk - ), pub_from) - - def publish_text_chunk(self, text: str, pub_from: PublishFrom) -> None: - """ - Publish text chunk to channel - - :param text: text - :param pub_from: publish from - :return: - """ - self.publish(QueueTextChunkEvent( - text=text - ), pub_from) - - def publish_agent_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None: - """ - Publish agent chunk message to channel - - :param chunk: chunk - :param pub_from: publish from - :return: - """ - self.publish(QueueAgentMessageEvent( - chunk=chunk - ), pub_from) - - def publish_message_replace(self, text: str, pub_from: PublishFrom) -> None: - """ - Publish message replace - :param text: text - :param pub_from: publish from - :return: - """ - self.publish(QueueMessageReplaceEvent( - text=text - ), pub_from) - - def publish_retriever_resources(self, retriever_resources: list[dict], pub_from: PublishFrom) -> None: - """ - Publish retriever resources - :return: - """ - self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources), pub_from) - - def publish_annotation_reply(self, message_annotation_id: str, pub_from: PublishFrom) -> None: - """ - Publish annotation reply - :param message_annotation_id: message annotation id - :param pub_from: 
publish from - :return: - """ - self.publish(QueueAnnotationReplyEvent(message_annotation_id=message_annotation_id), pub_from) - - def publish_message_end(self, llm_result: LLMResult, pub_from: PublishFrom) -> None: - """ - Publish message end - :param llm_result: llm result - :param pub_from: publish from - :return: - """ - self.publish(QueueMessageEndEvent(llm_result=llm_result), pub_from) - self.stop_listen() - - def publish_workflow_started(self, workflow_run_id: str, pub_from: PublishFrom) -> None: - """ - Publish workflow started - :param workflow_run_id: workflow run id - :param pub_from: publish from - :return: - """ - self.publish(QueueWorkflowStartedEvent(workflow_run_id=workflow_run_id), pub_from) - - def publish_workflow_finished(self, workflow_run_id: str, pub_from: PublishFrom) -> None: - """ - Publish workflow finished - :param workflow_run_id: workflow run id - :param pub_from: publish from - :return: - """ - self.publish(QueueWorkflowFinishedEvent(workflow_run_id=workflow_run_id), pub_from) - - def publish_node_started(self, workflow_node_execution_id: str, pub_from: PublishFrom) -> None: - """ - Publish node started - :param workflow_node_execution_id: workflow node execution id - :param pub_from: publish from - :return: - """ - self.publish(QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution_id), pub_from) - - def publish_node_finished(self, workflow_node_execution_id: str, pub_from: PublishFrom) -> None: - """ - Publish node finished - :param workflow_node_execution_id: workflow node execution id - :param pub_from: publish from - :return: - """ - self.publish(QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution_id), pub_from) - - def publish_agent_thought(self, message_agent_thought: MessageAgentThought, pub_from: PublishFrom) -> None: - """ - Publish agent thought - :param message_agent_thought: message agent thought - :param pub_from: publish from - :return: - """ - self.publish(QueueAgentThoughtEvent( - agent_thought_id=message_agent_thought.id - ), pub_from) - - def publish_message_file(self, message_file: MessageFile, pub_from: PublishFrom) -> None: - """ - Publish agent thought - :param message_file: message file - :param pub_from: publish from - :return: - """ - self.publish(QueueMessageFileEvent( - message_file_id=message_file.id - ), pub_from) - - def publish_error(self, e, pub_from: PublishFrom) -> None: - """ - Publish error - :param e: error - :param pub_from: publish from - :return: - """ - self.publish(QueueErrorEvent( - error=e - ), pub_from) - self.stop_listen() - - def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: - """ - Publish event to queue - :param event: - :param pub_from: - :return: - """ - self._check_for_sqlalchemy_models(event.dict()) - - message = QueueMessage( - task_id=self._task_id, - message_id=self._message_id, - conversation_id=self._conversation_id, - app_mode=self._app_mode, - event=event - ) - - self._q.put(message) - - if isinstance(event, QueueStopEvent): - self.stop_listen() - - if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): - raise ConversationTaskStoppedException() - - @classmethod - def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None: - """ - Set task stop flag - :return: - """ - result = redis_client.get(cls._generate_task_belong_cache_key(task_id)) - if result is None: - return - - user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' - if result.decode('utf-8') 
!= f"{user_prefix}-{user_id}": - return - - stopped_cache_key = cls._generate_stopped_cache_key(task_id) - redis_client.setex(stopped_cache_key, 600, 1) - - def _is_stopped(self) -> bool: - """ - Check if task is stopped - :return: - """ - stopped_cache_key = AppQueueManager._generate_stopped_cache_key(self._task_id) - result = redis_client.get(stopped_cache_key) - if result is not None: - return True - - return False - - @classmethod - def _generate_task_belong_cache_key(cls, task_id: str) -> str: - """ - Generate task belong cache key - :param task_id: task id - :return: - """ - return f"generate_task_belong:{task_id}" - - @classmethod - def _generate_stopped_cache_key(cls, task_id: str) -> str: - """ - Generate stopped cache key - :param task_id: task id - :return: - """ - return f"generate_task_stopped:{task_id}" - - def _check_for_sqlalchemy_models(self, data: Any): - # from entity to dict or list - if isinstance(data, dict): - for key, value in data.items(): - self._check_for_sqlalchemy_models(value) - elif isinstance(data, list): - for item in data: - self._check_for_sqlalchemy_models(item) - else: - if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'): - raise TypeError("Critical Error: Passing SQLAlchemy Model instances " - "that cause thread safety issues is not allowed.") - - -class ConversationTaskStoppedException(Exception): - pass diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 937f95679a..a19a5c8f67 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -8,11 +8,12 @@ from flask import Flask, current_app from pydantic import ValidationError from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline +from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError @@ -101,7 +102,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): ) = self._init_generate_records(application_generate_entity, conversation) # init queue manager - queue_manager = AppQueueManager( + queue_manager = MessageBasedAppQueueManager( task_id=application_generate_entity.task_id, user_id=application_generate_entity.user_id, invoke_from=application_generate_entity.invoke_from, diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index c5ffa80165..8fff8fc37e 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -2,14 +2,14 @@ import logging import time from typing import cast -from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.advanced_chat.app_config_manager import 
AdvancedChatAppConfig +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, InvokeFrom, ) -from core.app.entities.queue_entities import QueueStopEvent +from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.moderation.base import ModerationException from core.workflow.entities.node_entities import SystemVariable @@ -93,7 +93,7 @@ class AdvancedChatAppRunner(AppRunner): SystemVariable.FILES: files, SystemVariable.CONVERSATION: conversation.id, }, - callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)], + callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)] ) def handle_input_moderation(self, queue_manager: AppQueueManager, @@ -153,9 +153,9 @@ class AdvancedChatAppRunner(AppRunner): ) if annotation_reply: - queue_manager.publish_annotation_reply( - message_annotation_id=annotation_reply.id, - pub_from=PublishFrom.APPLICATION_MANAGER + queue_manager.publish( + QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), + PublishFrom.APPLICATION_MANAGER ) self._stream_output( @@ -182,7 +182,11 @@ class AdvancedChatAppRunner(AppRunner): if stream: index = 0 for token in text: - queue_manager.publish_text_chunk(token, PublishFrom.APPLICATION_MANAGER) + queue_manager.publish( + QueueTextChunkEvent( + text=token + ), PublishFrom.APPLICATION_MANAGER + ) index += 1 time.sleep(0.01) @@ -190,4 +194,3 @@ class AdvancedChatAppRunner(AppRunner): QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER ) - queue_manager.stop_listen() diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index cfeb46f05a..84352f16c7 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -6,7 +6,7 @@ from typing import Optional, Union from pydantic import BaseModel -from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, InvokeFrom, @@ -46,6 +46,7 @@ class TaskState(BaseModel): """ answer: str = "" metadata: dict = {} + usage: LLMUsage class AdvancedChatAppGenerateTaskPipeline: @@ -349,7 +350,12 @@ class AdvancedChatAppGenerateTaskPipeline: if self._output_moderation_handler.should_direct_output(): # stop subscribe new token when output moderation should direct output self._task_state.answer = self._output_moderation_handler.get_final_output() - self._queue_manager.publish_text_chunk(self._task_state.answer, PublishFrom.TASK_PIPELINE) + self._queue_manager.publish( + QueueTextChunkEvent( + text=self._task_state.answer + ), PublishFrom.TASK_PIPELINE + ) + self._queue_manager.publish( QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE @@ -558,5 +564,5 @@ class AdvancedChatAppGenerateTaskPipeline: type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config ), - on_message_replace_func=self._queue_manager.publish_message_replace + queue_manager=self._queue_manager ) diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py 
index d5dbdf0dd2..6d27620a09 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -9,10 +9,11 @@ from pydantic import ValidationError from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.agent_chat.app_runner import AgentChatAppRunner +from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError @@ -119,7 +120,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): ) = self._init_generate_records(application_generate_entity, conversation) # init queue manager - queue_manager = AppQueueManager( + queue_manager = MessageBasedAppQueueManager( task_id=application_generate_entity.task_id, user_id=application_generate_entity.user_id, invoke_from=application_generate_entity.invoke_from, diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 27a473fb17..2e142c63f1 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -4,10 +4,11 @@ from typing import cast from core.agent.cot_agent_runner import CotAgentRunner from core.agent.entities import AgentEntity from core.agent.fc_agent_runner import FunctionCallAgentRunner -from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ModelConfigWithCredentialsEntity +from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMUsage @@ -120,10 +121,11 @@ class AgentChatAppRunner(AppRunner): ) if annotation_reply: - queue_manager.publish_annotation_reply( - message_annotation_id=annotation_reply.id, - pub_from=PublishFrom.APPLICATION_MANAGER + queue_manager.publish( + QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), + PublishFrom.APPLICATION_MANAGER ) + self.direct_output( queue_manager=queue_manager, app_generate_entity=application_generate_entity, diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py new file mode 100644 index 0000000000..0391599040 --- /dev/null +++ b/api/core/app/apps/base_app_queue_manager.py @@ -0,0 +1,181 @@ +import queue +import time +from abc import abstractmethod +from collections.abc import Generator +from enum import Enum +from typing import Any + +from sqlalchemy.orm import DeclarativeMeta + +from core.app.entities.app_invoke_entities import InvokeFrom +from 
core.app.entities.queue_entities import ( + AppQueueEvent, + QueueErrorEvent, + QueueMessage, + QueueMessageEndEvent, + QueuePingEvent, + QueueStopEvent, +) +from extensions.ext_redis import redis_client + + +class PublishFrom(Enum): + APPLICATION_MANAGER = 1 + TASK_PIPELINE = 2 + + +class AppQueueManager: + def __init__(self, task_id: str, + user_id: str, + invoke_from: InvokeFrom) -> None: + if not user_id: + raise ValueError("user is required") + + self._task_id = task_id + self._user_id = user_id + self._invoke_from = invoke_from + + user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' + redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}") + + q = queue.Queue() + + self._q = q + + def listen(self) -> Generator: + """ + Listen to queue + :return: + """ + # wait for 10 minutes to stop listen + listen_timeout = 600 + start_time = time.time() + last_ping_time = 0 + + while True: + try: + message = self._q.get(timeout=1) + if message is None: + break + + yield message + except queue.Empty: + continue + finally: + elapsed_time = time.time() - start_time + if elapsed_time >= listen_timeout or self._is_stopped(): + # publish two messages to make sure the client can receive the stop signal + # and stop listening after the stop signal processed + self.publish( + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), + PublishFrom.TASK_PIPELINE + ) + + if elapsed_time // 10 > last_ping_time: + self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) + last_ping_time = elapsed_time // 10 + + def stop_listen(self) -> None: + """ + Stop listen to queue + :return: + """ + self._q.put(None) + + def publish_error(self, e, pub_from: PublishFrom) -> None: + """ + Publish error + :param e: error + :param pub_from: publish from + :return: + """ + self.publish(QueueErrorEvent( + error=e + ), pub_from) + + def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: + """ + Publish event to queue + :param event: + :param pub_from: + :return: + """ + self._check_for_sqlalchemy_models(event.dict()) + + message = self.construct_queue_message(event) + + self._q.put(message) + + if isinstance(event, QueueStopEvent | QueueErrorEvent | QueueMessageEndEvent): + self.stop_listen() + + if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): + raise ConversationTaskStoppedException() + + @abstractmethod + def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage: + raise NotImplementedError + + @classmethod + def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None: + """ + Set task stop flag + :return: + """ + result = redis_client.get(cls._generate_task_belong_cache_key(task_id)) + if result is None: + return + + user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' + if result.decode('utf-8') != f"{user_prefix}-{user_id}": + return + + stopped_cache_key = cls._generate_stopped_cache_key(task_id) + redis_client.setex(stopped_cache_key, 600, 1) + + def _is_stopped(self) -> bool: + """ + Check if task is stopped + :return: + """ + stopped_cache_key = AppQueueManager._generate_stopped_cache_key(self._task_id) + result = redis_client.get(stopped_cache_key) + if result is not None: + return True + + return False + + @classmethod + def _generate_task_belong_cache_key(cls, task_id: str) -> str: + """ + Generate task belong cache key + :param task_id: task id + :return: + """ + return 
f"generate_task_belong:{task_id}" + + @classmethod + def _generate_stopped_cache_key(cls, task_id: str) -> str: + """ + Generate stopped cache key + :param task_id: task id + :return: + """ + return f"generate_task_stopped:{task_id}" + + def _check_for_sqlalchemy_models(self, data: Any): + # from entity to dict or list + if isinstance(data, dict): + for key, value in data.items(): + self._check_for_sqlalchemy_models(value) + elif isinstance(data, list): + for item in data: + self._check_for_sqlalchemy_models(item) + else: + if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'): + raise TypeError("Critical Error: Passing SQLAlchemy Model instances " + "that cause thread safety issues is not allowed.") + + +class ConversationTaskStoppedException(Exception): + pass diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index dda240d778..e7ce7f25ef 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -3,13 +3,14 @@ from collections.abc import Generator from typing import Optional, Union, cast from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity -from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import ( AppGenerateEntity, EasyUIBasedAppGenerateEntity, InvokeFrom, ModelConfigWithCredentialsEntity, ) +from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature from core.external_data_tool.external_data_fetch import ExternalDataFetch @@ -187,25 +188,32 @@ class AppRunner: if stream: index = 0 for token in text: - queue_manager.publish_llm_chunk(LLMResultChunk( + chunk = LLMResultChunk( model=app_generate_entity.model_config.model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, message=AssistantPromptMessage(content=token) ) - ), PublishFrom.APPLICATION_MANAGER) + ) + + queue_manager.publish( + QueueLLMChunkEvent( + chunk=chunk + ), PublishFrom.APPLICATION_MANAGER + ) index += 1 time.sleep(0.01) - queue_manager.publish_message_end( - llm_result=LLMResult( - model=app_generate_entity.model_config.model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage(content=text), - usage=usage if usage else LLMUsage.empty_usage() - ), - pub_from=PublishFrom.APPLICATION_MANAGER + queue_manager.publish( + QueueMessageEndEvent( + llm_result=LLMResult( + model=app_generate_entity.model_config.model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage(content=text), + usage=usage if usage else LLMUsage.empty_usage() + ), + ), PublishFrom.APPLICATION_MANAGER ) def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator], @@ -241,9 +249,10 @@ class AppRunner: :param queue_manager: application queue manager :return: """ - queue_manager.publish_message_end( - llm_result=invoke_result, - pub_from=PublishFrom.APPLICATION_MANAGER + queue_manager.publish( + QueueMessageEndEvent( + llm_result=invoke_result, + ), PublishFrom.APPLICATION_MANAGER ) def _handle_invoke_result_stream(self, invoke_result: Generator, @@ -261,9 +270,17 @@ class AppRunner: usage = None for result in invoke_result: if not agent: - queue_manager.publish_llm_chunk(result, 
PublishFrom.APPLICATION_MANAGER) + queue_manager.publish( + QueueLLMChunkEvent( + chunk=result + ), PublishFrom.APPLICATION_MANAGER + ) else: - queue_manager.publish_agent_chunk_message(result, PublishFrom.APPLICATION_MANAGER) + queue_manager.publish( + QueueAgentMessageEvent( + chunk=result + ), PublishFrom.APPLICATION_MANAGER + ) text += result.delta.message.content @@ -286,9 +303,10 @@ class AppRunner: usage=usage ) - queue_manager.publish_message_end( - llm_result=llm_result, - pub_from=PublishFrom.APPLICATION_MANAGER + queue_manager.publish( + QueueMessageEndEvent( + llm_result=llm_result, + ), PublishFrom.APPLICATION_MANAGER ) def moderation_for_inputs(self, app_id: str, @@ -311,7 +329,7 @@ class AppRunner: tenant_id=tenant_id, app_config=app_generate_entity.app_config, inputs=inputs, - query=query, + query=query if query else '' ) def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity, diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 978ac9656b..7ddf8dfe32 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -9,10 +9,11 @@ from pydantic import ValidationError from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.chat.app_runner import ChatAppRunner from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError @@ -119,7 +120,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): ) = self._init_generate_records(application_generate_entity, conversation) # init queue manager - queue_manager = AppQueueManager( + queue_manager = MessageBasedAppQueueManager( task_id=application_generate_entity.task_id, user_id=application_generate_entity.user_id, invoke_from=application_generate_entity.invoke_from, diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index bce4606f21..d51f3db540 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -1,12 +1,13 @@ import logging from typing import cast -from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.apps.chat.app_config_manager import ChatAppConfig from core.app.entities.app_invoke_entities import ( ChatAppGenerateEntity, ) +from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance @@ -117,10 +118,11 @@ class ChatAppRunner(AppRunner): ) if annotation_reply: - queue_manager.publish_annotation_reply( - 
message_annotation_id=annotation_reply.id, - pub_from=PublishFrom.APPLICATION_MANAGER + queue_manager.publish( + QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), + PublishFrom.APPLICATION_MANAGER ) + self.direct_output( queue_manager=queue_manager, app_generate_entity=application_generate_entity, diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 9355bae123..7150bee3ce 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -9,10 +9,11 @@ from pydantic import ValidationError from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.app.apps.completion.app_runner import CompletionAppRunner from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError @@ -112,7 +113,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): ) = self._init_generate_records(application_generate_entity) # init queue manager - queue_manager = AppQueueManager( + queue_manager = MessageBasedAppQueueManager( task_id=application_generate_entity.task_id, user_id=application_generate_entity.user_id, invoke_from=application_generate_entity.invoke_from, @@ -263,7 +264,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): ) = self._init_generate_records(application_generate_entity) # init queue manager - queue_manager = AppQueueManager( + queue_manager = MessageBasedAppQueueManager( task_id=application_generate_entity.task_id, user_id=application_generate_entity.user_id, invoke_from=application_generate_entity.invoke_from, diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index d67d485e1d..04adf77be5 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -1,7 +1,7 @@ import logging from typing import cast -from core.app.app_queue_manager import AppQueueManager +from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_runner import AppRunner from core.app.apps.completion.app_config_manager import CompletionAppConfig from core.app.entities.app_invoke_entities import ( diff --git a/api/core/app/apps/easy_ui_based_generate_task_pipeline.py b/api/core/app/apps/easy_ui_based_generate_task_pipeline.py index 80596668b8..856bfb623d 100644 --- a/api/core/app/apps/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/apps/easy_ui_based_generate_task_pipeline.py @@ -6,7 +6,7 @@ from typing import Optional, Union, cast from pydantic import BaseModel -from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import ( AgentChatAppGenerateEntity, 
ChatAppGenerateEntity, @@ -385,14 +385,19 @@ class EasyUIBasedGenerateTaskPipeline: if self._output_moderation_handler.should_direct_output(): # stop subscribe new token when output moderation should direct output self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output() - self._queue_manager.publish_llm_chunk(LLMResultChunk( - model=self._task_state.llm_result.model, - prompt_messages=self._task_state.llm_result.prompt_messages, - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content=self._task_state.llm_result.message.content) - ) - ), PublishFrom.TASK_PIPELINE) + self._queue_manager.publish( + QueueLLMChunkEvent( + chunk=LLMResultChunk( + model=self._task_state.llm_result.model, + prompt_messages=self._task_state.llm_result.prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=self._task_state.llm_result.message.content) + ) + ) + ), PublishFrom.TASK_PIPELINE + ) + self._queue_manager.publish( QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE @@ -664,5 +669,5 @@ class EasyUIBasedGenerateTaskPipeline: type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config ), - on_message_replace_func=self._queue_manager.publish_message_replace + queue_manager=self._queue_manager ) diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index dab72bd6d6..3dee68b5e1 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -6,8 +6,8 @@ from typing import Optional, Union from sqlalchemy import and_ from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom -from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException from core.app.apps.base_app_generator import BaseAppGenerator +from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException from core.app.apps.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py new file mode 100644 index 0000000000..ed9475502d --- /dev/null +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -0,0 +1,29 @@ +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import ( + AppQueueEvent, + QueueMessage, +) + + +class MessageBasedAppQueueManager(AppQueueManager): + def __init__(self, task_id: str, + user_id: str, + invoke_from: InvokeFrom, + conversation_id: str, + app_mode: str, + message_id: str) -> None: + super().__init__(task_id, user_id, invoke_from) + + self._conversation_id = str(conversation_id) + self._app_mode = app_mode + self._message_id = str(message_id) + + def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage: + return QueueMessage( + task_id=self._task_id, + message_id=self._message_id, + conversation_id=self._conversation_id, + app_mode=self._app_mode, + event=event + ) diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py new file mode 100644 index 0000000000..891ca4c2be --- /dev/null +++ b/api/core/app/apps/workflow/app_generator.py @@ -0,0 +1,164 @@ +import logging +import threading 
+import uuid +from collections.abc import Generator +from typing import Union + +from flask import Flask, current_app +from pydantic import ValidationError + +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.apps.base_app_generator import BaseAppGenerator +from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom +from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager +from core.app.apps.workflow.app_runner import WorkflowAppRunner +from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.file.message_file_parser import MessageFileParser +from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from extensions.ext_database import db +from models.account import Account +from models.model import App, EndUser +from models.workflow import Workflow + +logger = logging.getLogger(__name__) + + +class WorkflowAppGenerator(BaseAppGenerator): + def generate(self, app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: bool = True) \ + -> Union[dict, Generator]: + """ + Generate App response. + + :param app_model: App + :param workflow: Workflow + :param user: account or end user + :param args: request args + :param invoke_from: invoke from source + :param stream: is stream + """ + inputs = args['inputs'] + + # parse files + files = args['files'] if 'files' in args and args['files'] else [] + message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + file_upload_entity = FileUploadConfigManager.convert(workflow.features_dict) + if file_upload_entity: + file_objs = message_file_parser.validate_and_transform_files_arg( + files, + file_upload_entity, + user + ) + else: + file_objs = [] + + # convert to app config + app_config = WorkflowAppConfigManager.get_app_config( + app_model=app_model, + workflow=workflow + ) + + # init application generate entity + application_generate_entity = WorkflowAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + inputs=self._get_cleaned_inputs(inputs, app_config), + files=file_objs, + user_id=user.id, + stream=stream, + invoke_from=invoke_from + ) + + # init queue manager + queue_manager = WorkflowAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + app_mode=app_model.mode + ) + + # new thread + worker_thread = threading.Thread(target=self._generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'application_generate_entity': application_generate_entity, + 'queue_manager': queue_manager + }) + + worker_thread.start() + + # return response or stream generator + return self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream + ) + + def _generate_worker(self, flask_app: Flask, + application_generate_entity: WorkflowAppGenerateEntity, + queue_manager: AppQueueManager) -> None: + """ + Generate worker in a new thread. 
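+ Runs the WorkflowAppRunner inside the Flask application context so database sessions are available in the worker thread; exceptions are published to the queue manager as error events rather than raised to the caller.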
+ :param flask_app: Flask app + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :return: + """ + with flask_app.app_context(): + try: + # workflow app + runner = WorkflowAppRunner() + runner.run( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager + ) + except ConversationTaskStoppedException: + pass + except InvokeAuthorizationError: + queue_manager.publish_error( + InvokeAuthorizationError('Incorrect API key provided'), + PublishFrom.APPLICATION_MANAGER + ) + except ValidationError as e: + logger.exception("Validation Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except (ValueError, InvokeError) as e: + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except Exception as e: + logger.exception("Unknown Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + finally: + db.session.remove() + + def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity, + queue_manager: AppQueueManager, + stream: bool = False) -> Union[dict, Generator]: + """ + Handle response. + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param stream: is stream + :return: + """ + # init generate task pipeline + generate_task_pipeline = WorkflowAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream + ) + + try: + return generate_task_pipeline.process() + except ValueError as e: + if e.args[0] == "I/O operation on closed file.": # ignore this error + raise ConversationTaskStoppedException() + else: + logger.exception(e) + raise e + finally: + db.session.remove() diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py new file mode 100644 index 0000000000..0f9b0a1c78 --- /dev/null +++ b/api/core/app/apps/workflow/app_queue_manager.py @@ -0,0 +1,23 @@ +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import ( + AppQueueEvent, + QueueMessage, +) + + +class WorkflowAppQueueManager(AppQueueManager): + def __init__(self, task_id: str, + user_id: str, + invoke_from: InvokeFrom, + app_mode: str) -> None: + super().__init__(task_id, user_id, invoke_from) + + self._app_mode = app_mode + + def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage: + return QueueMessage( + task_id=self._task_id, + app_mode=self._app_mode, + event=event + ) diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py new file mode 100644 index 0000000000..e675026e41 --- /dev/null +++ b/api/core/app/apps/workflow/app_runner.py @@ -0,0 +1,156 @@ +import logging +import time +from typing import cast + +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.workflow.app_config_manager import WorkflowAppConfig +from core.app.entities.app_invoke_entities import ( + AppGenerateEntity, + InvokeFrom, + WorkflowAppGenerateEntity, +) +from core.app.entities.queue_entities import QueueStopEvent, QueueTextChunkEvent +from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback +from core.moderation.base import ModerationException +from core.moderation.input_moderation import InputModeration +from 
core.workflow.entities.node_entities import SystemVariable +from core.workflow.workflow_engine_manager import WorkflowEngineManager +from extensions.ext_database import db +from models.account import Account +from models.model import App, EndUser +from models.workflow import WorkflowRunTriggeredFrom + +logger = logging.getLogger(__name__) + + +class WorkflowAppRunner: + """ + Workflow Application Runner + """ + + def run(self, application_generate_entity: WorkflowAppGenerateEntity, + queue_manager: AppQueueManager) -> None: + """ + Run application + :param application_generate_entity: application generate entity + :param queue_manager: application queue manager + :return: + """ + app_config = application_generate_entity.app_config + app_config = cast(WorkflowAppConfig, app_config) + + app_record = db.session.query(App).filter(App.id == app_config.app_id).first() + if not app_record: + raise ValueError("App not found") + + workflow = WorkflowEngineManager().get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) + if not workflow: + raise ValueError("Workflow not initialized") + + inputs = application_generate_entity.inputs + files = application_generate_entity.files + + # moderation + if self.handle_input_moderation( + queue_manager=queue_manager, + app_record=app_record, + app_generate_entity=application_generate_entity, + inputs=inputs + ): + return + + # fetch user + if application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE]: + user = db.session.query(Account).filter(Account.id == application_generate_entity.user_id).first() + else: + user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() + + # RUN WORKFLOW + workflow_engine_manager = WorkflowEngineManager() + workflow_engine_manager.run_workflow( + workflow=workflow, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING + if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN, + user=user, + user_inputs=inputs, + system_inputs={ + SystemVariable.FILES: files + }, + callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)] + ) + + def handle_input_moderation(self, queue_manager: AppQueueManager, + app_record: App, + app_generate_entity: WorkflowAppGenerateEntity, + inputs: dict) -> bool: + """ + Handle input moderation + :param queue_manager: application queue manager + :param app_record: app record + :param app_generate_entity: application generate entity + :param inputs: inputs + :return: + """ + try: + # process sensitive_word_avoidance + moderation_feature = InputModeration() + _, inputs, query = moderation_feature.check( + app_id=app_record.id, + tenant_id=app_generate_entity.app_config.tenant_id, + app_config=app_generate_entity.app_config, + inputs=inputs, + query='' + ) + except ModerationException as e: + if app_generate_entity.stream: + self._stream_output( + queue_manager=queue_manager, + text=str(e), + ) + + queue_manager.publish( + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION), + PublishFrom.APPLICATION_MANAGER + ) + return True + + return False + + def _stream_output(self, queue_manager: AppQueueManager, + text: str) -> None: + """ + Direct output + :param queue_manager: application queue manager + :param text: text + :return: + """ + index = 0 + for token in text: + queue_manager.publish( + QueueTextChunkEvent( + text=token + ), PublishFrom.APPLICATION_MANAGER + ) + index += 1 + time.sleep(0.01) + + def moderation_for_inputs(self, app_id: str, + tenant_id: 
str, + app_generate_entity: AppGenerateEntity, + inputs: dict) -> tuple[bool, dict, str]: + """ + Process sensitive_word_avoidance. + :param app_id: app id + :param tenant_id: tenant id + :param app_generate_entity: app generate entity + :param inputs: inputs + :return: + """ + moderation_feature = InputModeration() + return moderation_feature.check( + app_id=app_id, + tenant_id=tenant_id, + app_config=app_generate_entity.app_config, + inputs=inputs, + query='' + ) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py new file mode 100644 index 0000000000..df83ad634e --- /dev/null +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -0,0 +1,408 @@ +import json +import logging +import time +from collections.abc import Generator +from typing import Optional, Union + +from pydantic import BaseModel + +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.entities.app_invoke_entities import ( + WorkflowAppGenerateEntity, +) +from core.app.entities.queue_entities import ( + QueueErrorEvent, + QueueMessageReplaceEvent, + QueueNodeFinishedEvent, + QueueNodeStartedEvent, + QueuePingEvent, + QueueStopEvent, + QueueTextChunkEvent, + QueueWorkflowFinishedEvent, + QueueWorkflowStartedEvent, +) +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from core.moderation.output_moderation import ModerationRule, OutputModeration +from extensions.ext_database import db +from models.workflow import WorkflowNodeExecution, WorkflowRun, WorkflowRunStatus + +logger = logging.getLogger(__name__) + + +class TaskState(BaseModel): + """ + TaskState entity + """ + answer: str = "" + metadata: dict = {} + workflow_run_id: Optional[str] = None + + +class WorkflowAppGenerateTaskPipeline: + """ + WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. + """ + + def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, + queue_manager: AppQueueManager, + stream: bool) -> None: + """ + Initialize GenerateTaskPipeline. + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + """ + self._application_generate_entity = application_generate_entity + self._queue_manager = queue_manager + self._task_state = TaskState() + self._start_at = time.perf_counter() + self._output_moderation_handler = self._init_output_moderation() + self._stream = stream + + def process(self) -> Union[dict, Generator]: + """ + Process generate task pipeline. + :return: + """ + if self._stream: + return self._process_stream_response() + else: + return self._process_blocking_response() + + def _process_blocking_response(self) -> dict: + """ + Process blocking response. 
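+ Blocks on the queue listener until an error or a stop/finish event arrives, then returns a single 'workflow_finished' payload built from the persisted WorkflowRun.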
+ :return: + """ + for queue_message in self._queue_manager.listen(): + event = queue_message.event + + if isinstance(event, QueueErrorEvent): + raise self._handle_error(event) + elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent): + if isinstance(event, QueueStopEvent): + workflow_run = self._get_workflow_run(self._task_state.workflow_run_id) + else: + workflow_run = self._get_workflow_run(event.workflow_run_id) + + if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: + outputs = workflow_run.outputs + self._task_state.answer = outputs.get('text', '') + else: + raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))) + + # response moderation + if self._output_moderation_handler: + self._output_moderation_handler.stop_thread() + + self._task_state.answer = self._output_moderation_handler.moderation_completion( + completion=self._task_state.answer, + public_event=False + ) + + response = { + 'event': 'workflow_finished', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': event.workflow_run_id, + 'data': { + 'id': workflow_run.id, + 'workflow_id': workflow_run.workflow_id, + 'status': workflow_run.status, + 'outputs': workflow_run.outputs_dict, + 'error': workflow_run.error, + 'elapsed_time': workflow_run.elapsed_time, + 'total_tokens': workflow_run.total_tokens, + 'total_steps': workflow_run.total_steps, + 'created_at': int(workflow_run.created_at.timestamp()), + 'finished_at': int(workflow_run.finished_at.timestamp()) + } + } + + return response + else: + continue + + def _process_stream_response(self) -> Generator: + """ + Process stream response. + :return: + """ + for message in self._queue_manager.listen(): + event = message.event + + if isinstance(event, QueueErrorEvent): + data = self._error_to_stream_response_data(self._handle_error(event)) + yield self._yield_response(data) + break + elif isinstance(event, QueueWorkflowStartedEvent): + self._task_state.workflow_run_id = event.workflow_run_id + + workflow_run = self._get_workflow_run(event.workflow_run_id) + response = { + 'event': 'workflow_started', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': event.workflow_run_id, + 'data': { + 'id': workflow_run.id, + 'workflow_id': workflow_run.workflow_id, + 'created_at': int(workflow_run.created_at.timestamp()) + } + } + + yield self._yield_response(response) + elif isinstance(event, QueueNodeStartedEvent): + workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) + response = { + 'event': 'node_started', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': workflow_node_execution.workflow_run_id, + 'data': { + 'id': workflow_node_execution.id, + 'node_id': workflow_node_execution.node_id, + 'index': workflow_node_execution.index, + 'predecessor_node_id': workflow_node_execution.predecessor_node_id, + 'inputs': workflow_node_execution.inputs_dict, + 'created_at': int(workflow_node_execution.created_at.timestamp()) + } + } + + yield self._yield_response(response) + elif isinstance(event, QueueNodeFinishedEvent): + workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) + response = { + 'event': 'node_finished', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': workflow_node_execution.workflow_run_id, + 'data': { + 'id': workflow_node_execution.id, + 'node_id': workflow_node_execution.node_id, + 'index': workflow_node_execution.index, + 
'predecessor_node_id': workflow_node_execution.predecessor_node_id, + 'inputs': workflow_node_execution.inputs_dict, + 'process_data': workflow_node_execution.process_data_dict, + 'outputs': workflow_node_execution.outputs_dict, + 'status': workflow_node_execution.status, + 'error': workflow_node_execution.error, + 'elapsed_time': workflow_node_execution.elapsed_time, + 'execution_metadata': workflow_node_execution.execution_metadata_dict, + 'created_at': int(workflow_node_execution.created_at.timestamp()), + 'finished_at': int(workflow_node_execution.finished_at.timestamp()) + } + } + + yield self._yield_response(response) + elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent): + if isinstance(event, QueueStopEvent): + workflow_run = self._get_workflow_run(self._task_state.workflow_run_id) + else: + workflow_run = self._get_workflow_run(event.workflow_run_id) + + if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: + outputs = workflow_run.outputs + self._task_state.answer = outputs.get('text', '') + else: + err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) + data = self._error_to_stream_response_data(self._handle_error(err_event)) + yield self._yield_response(data) + break + + # response moderation + if self._output_moderation_handler: + self._output_moderation_handler.stop_thread() + + self._task_state.answer = self._output_moderation_handler.moderation_completion( + completion=self._task_state.answer, + public_event=False + ) + + self._output_moderation_handler = None + + replace_response = { + 'event': 'text_replace', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': self._task_state.workflow_run_id, + 'data': { + 'text': self._task_state.answer + } + } + + yield self._yield_response(replace_response) + + workflow_run_response = { + 'event': 'workflow_finished', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': event.workflow_run_id, + 'data': { + 'id': workflow_run.id, + 'workflow_id': workflow_run.workflow_id, + 'status': workflow_run.status, + 'outputs': workflow_run.outputs_dict, + 'error': workflow_run.error, + 'elapsed_time': workflow_run.elapsed_time, + 'total_tokens': workflow_run.total_tokens, + 'total_steps': workflow_run.total_steps, + 'created_at': int(workflow_run.created_at.timestamp()), + 'finished_at': int(workflow_run.finished_at.timestamp()) + } + } + + yield self._yield_response(workflow_run_response) + elif isinstance(event, QueueTextChunkEvent): + delta_text = event.chunk_text + if delta_text is None: + continue + + if self._output_moderation_handler: + if self._output_moderation_handler.should_direct_output(): + # stop subscribe new token when output moderation should direct output + self._task_state.answer = self._output_moderation_handler.get_final_output() + self._queue_manager.publish( + QueueTextChunkEvent( + text=self._task_state.answer + ), PublishFrom.TASK_PIPELINE + ) + + self._queue_manager.publish( + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), + PublishFrom.TASK_PIPELINE + ) + continue + else: + self._output_moderation_handler.append_new_token(delta_text) + + self._task_state.answer += delta_text + response = self._handle_chunk(delta_text) + yield self._yield_response(response) + elif isinstance(event, QueueMessageReplaceEvent): + response = { + 'event': 'text_replace', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': self._task_state.workflow_run_id, + 'data': { + 'text': event.text + } + } + 
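+ # 'text_replace' tells stream consumers to overwrite the text shown so far, e.g. after output moderation rewrites the answer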
+ yield self._yield_response(response) + elif isinstance(event, QueuePingEvent): + yield "event: ping\n\n" + else: + continue + + def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun: + """ + Get workflow run. + :param workflow_run_id: workflow run id + :return: + """ + return db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() + + def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution: + """ + Get workflow node execution. + :param workflow_node_execution_id: workflow node execution id + :return: + """ + return db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).first() + + def _handle_chunk(self, text: str) -> dict: + """ + Handle completed event. + :param text: text + :return: + """ + response = { + 'event': 'text_chunk', + 'workflow_run_id': self._task_state.workflow_run_id, + 'task_id': self._application_generate_entity.task_id, + 'data': { + 'text': text + } + } + + return response + + def _handle_error(self, event: QueueErrorEvent) -> Exception: + """ + Handle error event. + :param event: event + :return: + """ + logger.debug("error: %s", event.error) + e = event.error + + if isinstance(e, InvokeAuthorizationError): + return InvokeAuthorizationError('Incorrect API key provided') + elif isinstance(e, InvokeError) or isinstance(e, ValueError): + return e + else: + return Exception(e.description if getattr(e, 'description', None) is not None else str(e)) + + def _error_to_stream_response_data(self, e: Exception) -> dict: + """ + Error to stream response. + :param e: exception + :return: + """ + error_responses = { + ValueError: {'code': 'invalid_param', 'status': 400}, + ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400}, + QuotaExceededError: { + 'code': 'provider_quota_exceeded', + 'message': "Your quota for Dify Hosted Model Provider has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials.", + 'status': 400 + }, + ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400}, + InvokeError: {'code': 'completion_request_error', 'status': 400} + } + + # Determine the response based on the type of exception + data = None + for k, v in error_responses.items(): + if isinstance(e, k): + data = v + + if data: + data.setdefault('message', getattr(e, 'description', str(e))) + else: + logging.error(e) + data = { + 'code': 'internal_server_error', + 'message': 'Internal Server Error, please contact support.', + 'status': 500 + } + + return { + 'event': 'error', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': self._task_state.workflow_run_id, + **data + } + + def _yield_response(self, response: dict) -> str: + """ + Yield response. + :param response: response + :return: + """ + return "data: " + json.dumps(response) + "\n\n" + + def _init_output_moderation(self) -> Optional[OutputModeration]: + """ + Init output moderation. 
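+ Returns an OutputModeration bound to this pipeline's queue manager when sensitive_word_avoidance is configured, otherwise None.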
+ :return: + """ + app_config = self._application_generate_entity.app_config + sensitive_word_avoidance = app_config.sensitive_word_avoidance + + if sensitive_word_avoidance: + return OutputModeration( + tenant_id=app_config.tenant_id, + app_id=app_config.app_id, + rule=ModerationRule( + type=sensitive_word_avoidance.type, + config=sensitive_word_avoidance.config + ), + queue_manager=self._queue_manager + ) diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 1c4f32b8f2..01cbd7d2b2 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -127,9 +127,9 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity): query: Optional[str] = None -class WorkflowUIBasedAppGenerateEntity(AppGenerateEntity): +class WorkflowAppGenerateEntity(AppGenerateEntity): """ - Workflow UI Based Application Generate Entity. + Workflow Application Generate Entity. """ # app config app_config: WorkflowUIBasedAppConfig diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index ca781a55bc..8e1f496b22 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,6 +1,7 @@ -from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import DatasetQuery, DocumentSegment @@ -82,4 +83,7 @@ class DatasetIndexToolCallbackHandler: db.session.add(dataset_retriever_resource) db.session.commit() - self._queue_manager.publish_retriever_resources(resource, PublishFrom.APPLICATION_MANAGER) + self._queue_manager.publish( + QueueRetrieverResourcesEvent(retriever_resources=resource), + PublishFrom.APPLICATION_MANAGER + ) diff --git a/api/core/callback_handler/workflow_event_trigger_callback.py b/api/core/callback_handler/workflow_event_trigger_callback.py index 80dabc7548..f8bad94252 100644 --- a/api/core/callback_handler/workflow_event_trigger_callback.py +++ b/api/core/callback_handler/workflow_event_trigger_callback.py @@ -1,4 +1,11 @@ -from core.app.app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.entities.queue_entities import ( + QueueNodeFinishedEvent, + QueueNodeStartedEvent, + QueueTextChunkEvent, + QueueWorkflowFinishedEvent, + QueueWorkflowStartedEvent, +) from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from models.workflow import WorkflowNodeExecution, WorkflowRun @@ -12,43 +19,45 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): """ Workflow run started """ - self._queue_manager.publish_workflow_started( - workflow_run_id=workflow_run.id, - pub_from=PublishFrom.TASK_PIPELINE + self._queue_manager.publish( + QueueWorkflowStartedEvent(workflow_run_id=workflow_run.id), + PublishFrom.APPLICATION_MANAGER ) def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None: """ Workflow run finished """ - self._queue_manager.publish_workflow_finished( - workflow_run_id=workflow_run.id, - pub_from=PublishFrom.TASK_PIPELINE + self._queue_manager.publish( + 
QueueWorkflowFinishedEvent(workflow_run_id=workflow_run.id), + PublishFrom.APPLICATION_MANAGER ) def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None: """ Workflow node execute started """ - self._queue_manager.publish_node_started( - workflow_node_execution_id=workflow_node_execution.id, - pub_from=PublishFrom.TASK_PIPELINE + self._queue_manager.publish( + QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution.id), + PublishFrom.APPLICATION_MANAGER ) def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None: """ Workflow node execute finished """ - self._queue_manager.publish_node_finished( - workflow_node_execution_id=workflow_node_execution.id, - pub_from=PublishFrom.TASK_PIPELINE + self._queue_manager.publish( + QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution.id), + PublishFrom.APPLICATION_MANAGER ) + def on_text_chunk(self, text: str) -> None: """ Publish text chunk """ - self._queue_manager.publish_text_chunk( - text=text, - pub_from=PublishFrom.TASK_PIPELINE + self._queue_manager.publish( + QueueTextChunkEvent( + text=text + ), PublishFrom.APPLICATION_MANAGER ) diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py index 749ee431e8..af8910614d 100644 --- a/api/core/moderation/output_moderation.py +++ b/api/core/moderation/output_moderation.py @@ -6,7 +6,8 @@ from typing import Any, Optional from flask import Flask, current_app from pydantic import BaseModel -from core.app.app_queue_manager import PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.entities.queue_entities import QueueMessageReplaceEvent from core.moderation.base import ModerationAction, ModerationOutputsResult from core.moderation.factory import ModerationFactory @@ -25,7 +26,7 @@ class OutputModeration(BaseModel): app_id: str rule: ModerationRule - on_message_replace_func: Any + queue_manager: AppQueueManager thread: Optional[threading.Thread] = None thread_running: bool = True @@ -67,7 +68,12 @@ class OutputModeration(BaseModel): final_output = result.text if public_event: - self.on_message_replace_func(final_output, PublishFrom.TASK_PIPELINE) + self.queue_manager.publish( + QueueMessageReplaceEvent( + text=final_output + ), + PublishFrom.TASK_PIPELINE + ) return final_output @@ -117,7 +123,12 @@ class OutputModeration(BaseModel): # trigger replace event if self.thread_running: - self.on_message_replace_func(final_output, PublishFrom.TASK_PIPELINE) + self.queue_manager.publish( + QueueMessageReplaceEvent( + text=final_output + ), + PublishFrom.TASK_PIPELINE + ) if result.action == ModerationAction.DIRECT_OUTPUT: break diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 2c1b6eb819..144d136bdc 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -6,6 +6,7 @@ from typing import Optional, Union from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import NodeType from core.workflow.workflow_engine_manager import WorkflowEngineManager @@ -175,8 +176,24 @@ class 
WorkflowService: user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom) -> Union[dict, Generator]: - # TODO - pass + # fetch draft workflow by app_model + draft_workflow = self.get_draft_workflow(app_model=app_model) + + if not draft_workflow: + raise ValueError('Workflow not initialized') + + # run draft workflow + app_generator = WorkflowAppGenerator() + response = app_generator.generate( + app_model=app_model, + workflow=draft_workflow, + user=user, + args=args, + invoke_from=invoke_from, + stream=True + ) + + return response def convert_to_workflow(self, app_model: App, account: Account) -> App: """ From 079cc082a36252b841735952530aace430ec6ff1 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 7 Mar 2024 09:55:29 +0800 Subject: [PATCH 116/200] use callback to filter workflow stream output --- api/core/app/apps/advanced_chat/app_runner.py | 7 +- .../workflow_event_trigger_callback.py | 41 +++++++-- api/core/app/apps/workflow/app_runner.py | 7 +- .../workflow_event_trigger_callback.py | 87 +++++++++++++++++++ .../callbacks/base_workflow_callback.py | 6 +- api/core/workflow/nodes/base_node.py | 11 +-- api/core/workflow/workflow_engine_manager.py | 36 -------- 7 files changed, 138 insertions(+), 57 deletions(-) rename api/core/{callback_handler => app/apps/advanced_chat}/workflow_event_trigger_callback.py (55%) create mode 100644 api/core/app/apps/workflow/workflow_event_trigger_callback.py diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 8fff8fc37e..077f0c2de0 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -3,6 +3,7 @@ import time from typing import cast from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig +from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import ( @@ -10,7 +11,6 @@ from core.app.entities.app_invoke_entities import ( InvokeFrom, ) from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent -from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.moderation.base import ModerationException from core.workflow.entities.node_entities import SystemVariable from core.workflow.workflow_engine_manager import WorkflowEngineManager @@ -93,7 +93,10 @@ class AdvancedChatAppRunner(AppRunner): SystemVariable.FILES: files, SystemVariable.CONVERSATION: conversation.id, }, - callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)] + callbacks=[WorkflowEventTriggerCallback( + queue_manager=queue_manager, + workflow=workflow + )] ) def handle_input_moderation(self, queue_manager: AppQueueManager, diff --git a/api/core/callback_handler/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py similarity index 55% rename from api/core/callback_handler/workflow_event_trigger_callback.py rename to api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py index f8bad94252..44fb5905b0 100644 --- a/api/core/callback_handler/workflow_event_trigger_callback.py +++ b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py @@ -7,13 +7,15 @@ from core.app.entities.queue_entities import ( QueueWorkflowStartedEvent, ) from 
core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback -from models.workflow import WorkflowNodeExecution, WorkflowRun +from core.workflow.entities.node_entities import NodeType +from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun class WorkflowEventTriggerCallback(BaseWorkflowCallback): - def __init__(self, queue_manager: AppQueueManager): + def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): self._queue_manager = queue_manager + self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph) def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: """ @@ -51,13 +53,34 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): PublishFrom.APPLICATION_MANAGER ) - - def on_text_chunk(self, text: str) -> None: + def on_node_text_chunk(self, node_id: str, text: str) -> None: """ Publish text chunk """ - self._queue_manager.publish( - QueueTextChunkEvent( - text=text - ), PublishFrom.APPLICATION_MANAGER - ) + if node_id in self._streamable_node_ids: + self._queue_manager.publish( + QueueTextChunkEvent( + text=text + ), PublishFrom.APPLICATION_MANAGER + ) + + def _fetch_streamable_node_ids(self, graph: dict) -> list[str]: + """ + Fetch streamable node ids + When the workflow type is chat, only LLM and Direct Answer nodes that connect directly to an END node can stream output + When the workflow type is workflow, only LLM nodes that connect directly to an END node in plain-text mode can stream output + + :param graph: workflow graph + :return: + """ + streamable_node_ids = [] + end_node_ids = [] + for node_config in graph.get('nodes'): + if node_config.get('type') == NodeType.END.value: + end_node_ids.append(node_config.get('id')) + + for edge_config in graph.get('edges'): + if edge_config.get('target') in end_node_ids: + streamable_node_ids.append(edge_config.get('source')) + + return streamable_node_ids diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index e675026e41..132282ffe3 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -4,13 +4,13 @@ from typing import cast from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.workflow.app_config_manager import WorkflowAppConfig +from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.app.entities.app_invoke_entities import ( AppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity, ) from core.app.entities.queue_entities import QueueStopEvent, QueueTextChunkEvent -from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.moderation.base import ModerationException from core.moderation.input_moderation import InputModeration from core.workflow.entities.node_entities import SystemVariable @@ -76,7 +76,10 @@ class WorkflowAppRunner: system_inputs={ SystemVariable.FILES: files }, - callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)] + callbacks=[WorkflowEventTriggerCallback( + queue_manager=queue_manager, + workflow=workflow + )] ) def handle_input_moderation(self, queue_manager: AppQueueManager, diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py new file mode 100644 index 0000000000..57775f2cce --- /dev/null +++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py @@ -0,0 +1,87 @@ +from core.app.apps.base_app_queue_manager
import AppQueueManager, PublishFrom +from core.app.entities.queue_entities import ( + QueueNodeFinishedEvent, + QueueNodeStartedEvent, + QueueTextChunkEvent, + QueueWorkflowFinishedEvent, + QueueWorkflowStartedEvent, +) +from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback +from core.workflow.entities.node_entities import NodeType +from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun + + +class WorkflowEventTriggerCallback(BaseWorkflowCallback): + + def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): + self._queue_manager = queue_manager + self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph) + + def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: + """ + Workflow run started + """ + self._queue_manager.publish( + QueueWorkflowStartedEvent(workflow_run_id=workflow_run.id), + PublishFrom.APPLICATION_MANAGER + ) + + def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None: + """ + Workflow run finished + """ + self._queue_manager.publish( + QueueWorkflowFinishedEvent(workflow_run_id=workflow_run.id), + PublishFrom.APPLICATION_MANAGER + ) + + def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None: + """ + Workflow node execute started + """ + self._queue_manager.publish( + QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution.id), + PublishFrom.APPLICATION_MANAGER + ) + + def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None: + """ + Workflow node execute finished + """ + self._queue_manager.publish( + QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution.id), + PublishFrom.APPLICATION_MANAGER + ) + + def on_node_text_chunk(self, node_id: str, text: str) -> None: + """ + Publish text chunk + """ + if node_id in self._streamable_node_ids: + self._queue_manager.publish( + QueueTextChunkEvent( + text=text + ), PublishFrom.APPLICATION_MANAGER + ) + + def _fetch_streamable_node_ids(self, graph: dict) -> list[str]: + """ + Fetch streamable node ids + When the workflow type is chat, only LLM and Direct Answer nodes that connect directly to an END node can stream output + When the workflow type is workflow, only LLM nodes that connect directly to an END node in plain-text mode can stream output + + :param graph: workflow graph + :return: + """ + streamable_node_ids = [] + end_node_ids = [] + for node_config in graph.get('nodes'): + if node_config.get('type') == NodeType.END.value: + if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text': + end_node_ids.append(node_config.get('id')) + + for edge_config in graph.get('edges'): + if edge_config.get('target') in end_node_ids: + streamable_node_ids.append(edge_config.get('source')) + + return streamable_node_ids diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py index 3425b2b03c..3866bf2c15 100644 --- a/api/core/workflow/callbacks/base_workflow_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -1,9 +1,9 @@ -from abc import abstractmethod +from abc import ABC, abstractmethod from models.workflow import WorkflowNodeExecution, WorkflowRun -class BaseWorkflowCallback: +class BaseWorkflowCallback(ABC): @abstractmethod def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: """ @@ -33,7 +33,7 @@ class BaseWorkflowCallback: raise NotImplementedError @abstractmethod - def on_text_chunk(self, text:
str) -> None: + def on_node_text_chunk(self, node_id: str, text: str) -> None: """ Publish text chunk """ diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index efffdfae1a..1ff05f9f4e 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -16,7 +16,6 @@ class BaseNode: node_data: BaseNodeData node_run_result: Optional[NodeRunResult] = None - stream_output_supported: bool = False callbacks: list[BaseWorkflowCallback] def __init__(self, config: dict, @@ -71,10 +70,12 @@ class BaseNode: :param text: chunk text :return: """ - if self.stream_output_supported: - if self.callbacks: - for callback in self.callbacks: - callback.on_text_chunk(text) + if self.callbacks: + for callback in self.callbacks: + callback.on_node_text_chunk( + node_id=self.node_id, + text=text + ) @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 908b684930..4d881d3d04 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -32,7 +32,6 @@ from models.workflow import ( WorkflowRun, WorkflowRunStatus, WorkflowRunTriggeredFrom, - WorkflowType, ) node_classes = { @@ -171,9 +170,6 @@ class WorkflowEngineManager: ) ) - # fetch predecessor node ids before end node (include: llm, direct answer) - streamable_node_ids = self._fetch_streamable_node_ids(workflow, graph) - try: predecessor_node = None while True: @@ -187,10 +183,6 @@ class WorkflowEngineManager: if not next_node: break - # check if node is streamable - if next_node.node_id in streamable_node_ids: - next_node.stream_output_supported = True - # max steps 30 reached if len(workflow_run_state.workflow_node_executions) > 30: raise ValueError('Max steps 30 reached.') @@ -233,34 +225,6 @@ class WorkflowEngineManager: callbacks=callbacks ) - def _fetch_streamable_node_ids(self, workflow: Workflow, graph: dict) -> list[str]: - """ - Fetch streamable node ids - When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output - When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output - - :param workflow: Workflow instance - :param graph: workflow graph - :return: - """ - workflow_type = WorkflowType.value_of(workflow.type) - - streamable_node_ids = [] - end_node_ids = [] - for node_config in graph.get('nodes'): - if node_config.get('type') == NodeType.END.value: - if workflow_type == WorkflowType.WORKFLOW: - if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text': - end_node_ids.append(node_config.get('id')) - else: - end_node_ids.append(node_config.get('id')) - - for edge_config in graph.get('edges'): - if edge_config.get('target') in end_node_ids: - streamable_node_ids.append(edge_config.get('source')) - - return streamable_node_ids - def _init_workflow_run(self, workflow: Workflow, triggered_from: WorkflowRunTriggeredFrom, user: Union[Account, EndUser], From 3e54cb26beee1c23c31c8eaa2f01ef32a9e8f471 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 7 Mar 2024 10:09:23 +0800 Subject: [PATCH 117/200] move funcs --- api/core/workflow/workflow_engine_manager.py | 25 -------------------- api/services/workflow_service.py | 14 +++++++---- 2 files changed, 10 insertions(+), 29 deletions(-) diff --git a/api/core/workflow/workflow_engine_manager.py 
b/api/core/workflow/workflow_engine_manager.py index 4d881d3d04..8ab0eb4802 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -51,30 +51,6 @@ node_classes = { class WorkflowEngineManager: - def get_draft_workflow(self, app_model: App) -> Optional[Workflow]: - """ - Get draft workflow - """ - # fetch draft workflow by app_model - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.version == 'draft' - ).first() - - # return draft workflow - return workflow - - def get_published_workflow(self, app_model: App) -> Optional[Workflow]: - """ - Get published workflow - """ - if not app_model.workflow_id: - return None - - # fetch published workflow by workflow_id - return self.get_workflow(app_model, app_model.workflow_id) - def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: """ Get workflow @@ -404,7 +380,6 @@ class WorkflowEngineManager: :param max_execution_time: max execution time :return: """ - # TODO check queue is stopped return time.perf_counter() - start_at > max_execution_time def _run_workflow_node(self, workflow_run_state: WorkflowRunState, diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 144d136bdc..833c22cdff 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -26,22 +26,28 @@ class WorkflowService: """ Get draft workflow """ - workflow_engine_manager = WorkflowEngineManager() + # fetch draft workflow by app_model + workflow = db.session.query(Workflow).filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.version == 'draft' + ).first() # return draft workflow - return workflow_engine_manager.get_draft_workflow(app_model=app_model) + return workflow def get_published_workflow(self, app_model: App) -> Optional[Workflow]: """ Get published workflow """ + if not app_model.workflow_id: return None workflow_engine_manager = WorkflowEngineManager() - # return published workflow - return workflow_engine_manager.get_published_workflow(app_model=app_model) + # fetch published workflow by workflow_id + return workflow_engine_manager.get_workflow(app_model, app_model.workflow_id) def sync_draft_workflow(self, app_model: App, graph: dict, From 8684b172d201ef9414a6dff756f42f5439f809f0 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 7 Mar 2024 15:43:55 +0800 Subject: [PATCH 118/200] add start, end, direct answer node --- .../entities/base_node_data_entities.py | 2 - api/core/workflow/entities/node_entities.py | 13 ++++- .../workflow/entities/variable_entities.py | 9 +++ .../workflow/entities/workflow_entities.py | 7 ++- api/core/workflow/nodes/base_node.py | 4 +- .../nodes/direct_answer/direct_answer_node.py | 51 ++++++++++++++++- .../workflow/nodes/direct_answer/entities.py | 10 ++++ api/core/workflow/nodes/end/end_node.py | 57 ++++++++++++++++++- api/core/workflow/nodes/end/entities.py | 43 ++++++++++++++ api/core/workflow/nodes/llm/entities.py | 8 +++ api/core/workflow/nodes/llm/llm_node.py | 21 ++++++- api/core/workflow/nodes/start/entities.py | 16 +----- api/core/workflow/nodes/start/start_node.py | 56 ++++++++++++++++-- api/core/workflow/workflow_engine_manager.py | 8 ++- 14 files changed, 274 insertions(+), 31 deletions(-) create mode 100644 api/core/workflow/entities/variable_entities.py create mode 100644 api/core/workflow/nodes/direct_answer/entities.py create mode 100644 
api/core/workflow/nodes/llm/entities.py diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/entities/base_node_data_entities.py index afa6ddff04..fc6ee231ff 100644 --- a/api/core/workflow/entities/base_node_data_entities.py +++ b/api/core/workflow/entities/base_node_data_entities.py @@ -5,7 +5,5 @@ from pydantic import BaseModel class BaseNodeData(ABC, BaseModel): - type: str - title: str desc: Optional[str] = None diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index af539692ef..263172da31 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Optional +from typing import Any, Optional from pydantic import BaseModel @@ -46,6 +46,15 @@ class SystemVariable(Enum): CONVERSATION = 'conversation' +class NodeRunMetadataKey(Enum): + """ + Node Run Metadata Key. + """ + TOTAL_TOKENS = 'total_tokens' + TOTAL_PRICE = 'total_price' + CURRENCY = 'currency' + + class NodeRunResult(BaseModel): """ Node Run Result. @@ -55,7 +64,7 @@ class NodeRunResult(BaseModel): inputs: Optional[dict] = None # node inputs process_data: Optional[dict] = None # process data outputs: Optional[dict] = None # node outputs - metadata: Optional[dict] = None # node metadata + metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata edge_source_handle: Optional[str] = None # source handle id of node with multiple branches diff --git a/api/core/workflow/entities/variable_entities.py b/api/core/workflow/entities/variable_entities.py new file mode 100644 index 0000000000..19d9af2a61 --- /dev/null +++ b/api/core/workflow/entities/variable_entities.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel + + +class VariableSelector(BaseModel): + """ + Variable Selector. 
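+ Maps a template variable name to the selector path used to resolve its value from the variable pool.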
+ """ + variable: str + value_selector: list[str] diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 0d78e4c4f1..8c15cb95cd 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -5,13 +5,18 @@ from models.workflow import WorkflowNodeExecution, WorkflowRun class WorkflowRunState: workflow_run: WorkflowRun start_at: float + user_inputs: dict variable_pool: VariablePool total_tokens: int = 0 workflow_node_executions: list[WorkflowNodeExecution] = [] - def __init__(self, workflow_run: WorkflowRun, start_at: float, variable_pool: VariablePool) -> None: + def __init__(self, workflow_run: WorkflowRun, + start_at: float, + user_inputs: dict, + variable_pool: VariablePool) -> None: self.workflow_run = workflow_run self.start_at = start_at + self.user_inputs = user_inputs self.variable_pool = variable_pool diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 1ff05f9f4e..6720017d9f 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -1,4 +1,4 @@ -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Optional from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback @@ -8,7 +8,7 @@ from core.workflow.entities.variable_pool import VariablePool from models.workflow import WorkflowNodeExecutionStatus -class BaseNode: +class BaseNode(ABC): _node_data_cls: type[BaseNodeData] _node_type: NodeType diff --git a/api/core/workflow/nodes/direct_answer/direct_answer_node.py b/api/core/workflow/nodes/direct_answer/direct_answer_node.py index c6013974b8..80ecdf7757 100644 --- a/api/core/workflow/nodes/direct_answer/direct_answer_node.py +++ b/api/core/workflow/nodes/direct_answer/direct_answer_node.py @@ -1,5 +1,54 @@ +import time +from typing import Optional, cast + +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import ValueType, VariablePool from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.direct_answer.entities import DirectAnswerNodeData +from models.workflow import WorkflowNodeExecutionStatus class DirectAnswerNode(BaseNode): - pass + _node_data_cls = DirectAnswerNodeData + node_type = NodeType.DIRECT_ANSWER + + def _run(self, variable_pool: Optional[VariablePool] = None, + run_args: Optional[dict] = None) -> NodeRunResult: + """ + Run node + :param variable_pool: variable pool + :param run_args: run args + :return: + """ + node_data = self.node_data + node_data = cast(self._node_data_cls, node_data) + + if variable_pool is None and run_args: + raise ValueError("Not support single step debug.") + + variable_values = {} + for variable_selector in node_data.variables: + value = variable_pool.get_variable_value( + variable_selector=variable_selector.value_selector, + target_value_type=ValueType.STRING + ) + + variable_values[variable_selector.variable] = value + + # format answer template + template_parser = PromptTemplateParser(node_data.answer) + answer = template_parser.format(variable_values) + + # publish answer as stream + for word in answer: + self.publish_text_chunk(word) + time.sleep(0.01) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variable_values, + output={ + "answer": answer + } + ) diff --git 
a/api/core/workflow/nodes/direct_answer/entities.py b/api/core/workflow/nodes/direct_answer/entities.py new file mode 100644 index 0000000000..e7c11e3c4d --- /dev/null +++ b/api/core/workflow/nodes/direct_answer/entities.py @@ -0,0 +1,10 @@ +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_entities import VariableSelector + + +class DirectAnswerNodeData(BaseNodeData): + """ + DirectAnswer Node Data. + """ + variables: list[VariableSelector] = [] + answer: str diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index f9aea89af7..62429e3ac2 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,5 +1,60 @@ +from typing import Optional, cast + +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import ValueType, VariablePool from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.end.entities import EndNodeData, EndNodeDataOutputs +from models.workflow import WorkflowNodeExecutionStatus class EndNode(BaseNode): - pass + _node_data_cls = EndNodeData + node_type = NodeType.END + + def _run(self, variable_pool: Optional[VariablePool] = None, + run_args: Optional[dict] = None) -> NodeRunResult: + """ + Run node + :param variable_pool: variable pool + :param run_args: run args + :return: + """ + node_data = self.node_data + node_data = cast(self._node_data_cls, node_data) + outputs_config = node_data.outputs + + if variable_pool is not None: + outputs = None + if outputs_config: + if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT: + plain_text_selector = outputs_config.plain_text_selector + if plain_text_selector: + outputs = { + 'text': variable_pool.get_variable_value( + variable_selector=plain_text_selector, + target_value_type=ValueType.STRING + ) + } + else: + outputs = { + 'text': '' + } + elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED: + structured_variables = outputs_config.structured_variables + if structured_variables: + outputs = {} + for variable_selector in structured_variables: + variable_value = variable_pool.get_variable_value( + variable_selector=variable_selector.value_selector + ) + outputs[variable_selector.variable] = variable_value + else: + outputs = {} + else: + raise ValueError("Not support single step debug.") + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=outputs, + outputs=outputs + ) diff --git a/api/core/workflow/nodes/end/entities.py b/api/core/workflow/nodes/end/entities.py index 045e7effc4..32212ae7fa 100644 --- a/api/core/workflow/nodes/end/entities.py +++ b/api/core/workflow/nodes/end/entities.py @@ -1,4 +1,10 @@ from enum import Enum +from typing import Optional + +from pydantic import BaseModel + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_entities import VariableSelector class EndNodeOutputType(Enum): @@ -23,3 +29,40 @@ class EndNodeOutputType(Enum): if output_type.value == value: return output_type raise ValueError(f'invalid output type value {value}') + + +class EndNodeDataOutputs(BaseModel): + """ + END Node Data Outputs. + """ + class OutputType(Enum): + """ + Output Types. + """ + NONE = 'none' + PLAIN_TEXT = 'plain-text' + STRUCTURED = 'structured' + + @classmethod + def value_of(cls, value: str) -> 'OutputType': + """ + Get value of given output type. 
+ + :param value: output type value + :return: output type + """ + for output_type in cls: + if output_type.value == value: + return output_type + raise ValueError(f'invalid output type value {value}') + + type: OutputType = OutputType.NONE + plain_text_selector: Optional[list[str]] = None + structured_variables: Optional[list[VariableSelector]] = None + + +class EndNodeData(BaseNodeData): + """ + END Node Data. + """ + outputs: Optional[EndNodeDataOutputs] = None diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py new file mode 100644 index 0000000000..bd499543d9 --- /dev/null +++ b/api/core/workflow/nodes/llm/entities.py @@ -0,0 +1,8 @@ +from core.workflow.entities.base_node_data_entities import BaseNodeData + + +class LLMNodeData(BaseNodeData): + """ + LLM Node Data. + """ + pass diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 1c7277e942..e3ae9fc00f 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -1,9 +1,28 @@ -from typing import Optional +from typing import Optional, cast +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.llm.entities import LLMNodeData class LLMNode(BaseNode): + _node_data_cls = LLMNodeData + node_type = NodeType.LLM + + def _run(self, variable_pool: Optional[VariablePool] = None, + run_args: Optional[dict] = None) -> NodeRunResult: + """ + Run node + :param variable_pool: variable pool + :param run_args: run args + :return: + """ + node_data = self.node_data + node_data = cast(self._node_data_cls, node_data) + + pass + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ diff --git a/api/core/workflow/nodes/start/entities.py b/api/core/workflow/nodes/start/entities.py index 64687db042..0bd5f203bf 100644 --- a/api/core/workflow/nodes/start/entities.py +++ b/api/core/workflow/nodes/start/entities.py @@ -1,23 +1,9 @@ from core.app.app_config.entities import VariableEntity from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType class StartNodeData(BaseNodeData): """ - - title (string) 节点标题 - - desc (string) optional 节点描述 - - type (string) 节点类型,固定为 start - - variables (array[object]) 表单变量列表 - - type (string) 表单变量类型,text-input, paragraph, select, number, files(文件暂不支持自定义) - - label (string) 控件展示标签名 - - variable (string) 变量 key - - max_length (int) 最大长度,适用于 text-input 和 paragraph - - default (string) optional 默认值 - - required (bool) optional是否必填,默认 false - - hint (string) optional 提示信息 - - options (array[string]) 选项值(仅 select 可用) + Start Node Data """ - type: str = NodeType.START.value - variables: list[VariableEntity] = [] diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 74d8541436..ce04031b04 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,9 +1,11 @@ -from typing import Optional +from typing import Optional, cast -from core.workflow.entities.node_entities import NodeType +from core.app.app_config.entities import VariableEntity +from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from 
core.workflow.nodes.start.entities import StartNodeData +from models.workflow import WorkflowNodeExecutionStatus class StartNode(BaseNode): @@ -11,12 +13,58 @@ class StartNode(BaseNode): node_type = NodeType.START def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> dict: + run_args: Optional[dict] = None) -> NodeRunResult: """ Run node :param variable_pool: variable pool :param run_args: run args :return: """ - pass + node_data = self.node_data + node_data = cast(self._node_data_cls, node_data) + variables = node_data.variables + # Get cleaned inputs + cleaned_inputs = self._get_cleaned_inputs(variables, run_args) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=cleaned_inputs, + outputs=cleaned_inputs + ) + + def _get_cleaned_inputs(self, variables: list[VariableEntity], user_inputs: dict): + if user_inputs is None: + user_inputs = {} + + filtered_inputs = {} + + for variable_config in variables: + variable = variable_config.variable + + if variable not in user_inputs or not user_inputs[variable]: + if variable_config.required: + raise ValueError(f"Input form variable {variable} is required") + else: + filtered_inputs[variable] = variable_config.default if variable_config.default is not None else "" + continue + + value = user_inputs[variable] + + if value: + if not isinstance(value, str): + raise ValueError(f"{variable} in input form must be a string") + + if variable_config.type == VariableEntity.Type.SELECT: + options = variable_config.options if variable_config.options is not None else [] + if value not in options: + raise ValueError(f"{variable} in input form must be one of the following: {options}") + else: + if variable_config.max_length is not None: + max_length = variable_config.max_length + if len(value) > max_length: + raise ValueError(f'{variable} in input form must be less than {max_length} characters') + + filtered_inputs[variable] = value.replace('\x00', '') if value else None + + return filtered_inputs diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 8ab0eb4802..5423546957 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -3,6 +3,7 @@ import time from datetime import datetime from typing import Optional, Union +from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool, VariableValue @@ -141,6 +142,7 @@ class WorkflowEngineManager: workflow_run_state = WorkflowRunState( workflow_run=workflow_run, start_at=time.perf_counter(), + user_inputs=user_inputs, variable_pool=VariablePool( system_variables=system_inputs, ) @@ -399,7 +401,9 @@ class WorkflowEngineManager: # run node, result must have inputs, process_data, outputs, execution_metadata node_run_result = node.run( - variable_pool=workflow_run_state.variable_pool + variable_pool=workflow_run_state.variable_pool, + run_args=workflow_run_state.user_inputs + if (not predecessor_node and node.node_type == NodeType.START) else None # only on start node ) if node_run_result.status == WorkflowNodeExecutionStatus.FAILED: @@ -492,7 +496,7 @@ class WorkflowEngineManager: workflow_node_execution.inputs = json.dumps(result.inputs) workflow_node_execution.process_data = json.dumps(result.process_data) 
workflow_node_execution.outputs = json.dumps(result.outputs) - workflow_node_execution.execution_metadata = json.dumps(result.metadata) + workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(result.metadata)) workflow_node_execution.finished_at = datetime.utcnow() db.session.commit() From 2ad9c76093aa1ccb7ceb4702a5bc2854c711897d Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 7 Mar 2024 16:31:35 +0800 Subject: [PATCH 119/200] modify migrations --- ...5564d_conversation_columns_set_nullable.py | 48 +++++++++++++++++++ .../versions/b289e2408ee2_add_workflow.py | 2 - 2 files changed, 48 insertions(+), 2 deletions(-) create mode 100644 api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py diff --git a/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py b/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py new file mode 100644 index 0000000000..f388b99b90 --- /dev/null +++ b/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py @@ -0,0 +1,48 @@ +"""conversation columns set nullable + +Revision ID: 42e85ed5564d +Revises: f9107f83abab +Create Date: 2024-03-07 08:30:29.133614 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '42e85ed5564d' +down_revision = 'f9107f83abab' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.alter_column('app_model_config_id', + existing_type=postgresql.UUID(), + nullable=True) + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=True) + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=True) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('app_model_config_id', + existing_type=postgresql.UUID(), + nullable=False) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py index 5ae1e65611..cf8530dc67 100644 --- a/api/migrations/versions/b289e2408ee2_add_workflow.py +++ b/api/migrations/versions/b289e2408ee2_add_workflow.py @@ -78,8 +78,6 @@ def upgrade(): sa.Column('error', sa.Text(), nullable=True), sa.Column('elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False), sa.Column('total_tokens', sa.Integer(), server_default=sa.text('0'), nullable=False), - sa.Column('total_price', sa.Numeric(precision=10, scale=7), nullable=True), - sa.Column('currency', sa.String(length=255), nullable=True), sa.Column('total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True), sa.Column('created_by_role', sa.String(length=255), nullable=False), sa.Column('created_by', postgresql.UUID(), nullable=False), From b174f852377e9c534cbc67c2dcd271364e487fc9 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 7 Mar 2024 17:15:46 +0800 Subject: [PATCH 120/200] fix bug --- api/controllers/console/app/workflow.py | 2 +- api/fields/app_fields.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 4f8df6bcec..5d70076821 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -65,7 +65,7 @@ class DraftWorkflowApi(Resource): return { "result": "success", - "updated_at": TimestampField().format(workflow.updated_at) + "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at) } diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 69ab1d3e3e..ccb95ad573 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -48,7 +48,7 @@ app_detail_fields = { 'icon_background': fields.String, 'enable_site': fields.Boolean, 'enable_api': fields.Boolean, - 'model_config': fields.Nested(model_config_fields, attribute='app_model_config'), + 'model_config': fields.Nested(model_config_fields, attribute='app_model_config', allow_null=True), 'created_at': TimestampField } @@ -68,7 +68,7 @@ app_partial_fields = { 'mode': fields.String, 'icon': fields.String, 'icon_background': fields.String, - 'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'), + 'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config', allow_null=True), 'created_at': TimestampField } @@ -118,7 +118,7 @@ app_detail_fields_with_site = { 'icon_background': fields.String, 'enable_site': fields.Boolean, 'enable_api': fields.Boolean, - 'model_config': fields.Nested(model_config_fields, attribute='app_model_config'), + 'model_config': fields.Nested(model_config_fields, attribute='app_model_config', allow_null=True), 'site': fields.Nested(site_fields), 'api_base_url': fields.String, 'created_at': TimestampField, From 1f986a3abbef7ae2cbcbdf0cd05acebeb48baeca Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 7 Mar 2024 19:45:02 +0800 Subject: [PATCH 121/200] fix bugs --- api/controllers/console/app/workflow.py | 28 ++++-- 
.../advanced_chat/generate_task_pipeline.py | 2 +- .../workflow_event_trigger_callback.py | 2 +- api/core/app/apps/chat/app_config_manager.py | 2 +- .../workflow_event_trigger_callback.py | 2 +- api/core/workflow/workflow_engine_manager.py | 99 +++++++++---------- .../versions/b289e2408ee2_add_workflow.py | 4 +- ...29b71023c_messages_columns_set_nullable.py | 41 ++++++++ api/models/model.py | 4 +- api/models/workflow.py | 6 +- 10 files changed, 118 insertions(+), 72 deletions(-) create mode 100644 api/migrations/versions/b5429b71023c_messages_columns_set_nullable.py diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 5d70076821..8a68cafad8 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -1,6 +1,7 @@ import json import logging from collections.abc import Generator +from typing import Union from flask import Response, stream_with_context from flask_restful import Resource, marshal_with, reqparse @@ -79,9 +80,9 @@ class AdvancedChatDraftWorkflowRunApi(Resource): Run draft workflow """ parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, location='json', default='') - parser.add_argument('files', type=list, required=False, location='json') + parser.add_argument('inputs', type=dict, location='json') + parser.add_argument('query', type=str, required=True, location='json', default='') + parser.add_argument('files', type=list, location='json') parser.add_argument('conversation_id', type=uuid_value, location='json') args = parser.parse_args() @@ -93,6 +94,8 @@ class AdvancedChatDraftWorkflowRunApi(Resource): args=args, invoke_from=InvokeFrom.DEBUGGER ) + + return compact_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.conversation.ConversationCompletedError: @@ -103,12 +106,6 @@ class AdvancedChatDraftWorkflowRunApi(Resource): logging.exception("internal server error.") raise InternalServerError() - def generate() -> Generator: - yield from response - - return Response(stream_with_context(generate()), status=200, - mimetype='text/event-stream') - class DraftWorkflowRunApi(Resource): @setup_required @@ -120,7 +117,7 @@ class DraftWorkflowRunApi(Resource): Run draft workflow """ parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') + parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') args = parser.parse_args() workflow_service = WorkflowService() @@ -280,6 +277,17 @@ class ConvertToWorkflowApi(Resource): return workflow +def compact_response(response: Union[dict, Generator]) -> Response: + if isinstance(response, dict): + return Response(response=json.dumps(response), status=200, mimetype='application/json') + else: + def generate() -> Generator: + yield from response + + return Response(stream_with_context(generate()), status=200, + mimetype='text/event-stream') + + api.add_resource(DraftWorkflowApi, '/apps//workflows/draft') api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps//advanced-chat/workflows/draft/run') api.add_resource(DraftWorkflowRunApi, '/apps//workflows/draft/run') diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 84352f16c7..624a0f430a 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py 
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -174,7 +174,7 @@ class AdvancedChatAppGenerateTaskPipeline: response = { 'event': 'workflow_started', 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': event.workflow_run_id, + 'workflow_run_id': workflow_run.id, 'data': { 'id': workflow_run.id, 'workflow_id': workflow_run.workflow_id, diff --git a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py index 44fb5905b0..5d99ce6297 100644 --- a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py +++ b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py @@ -15,7 +15,7 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): self._queue_manager = queue_manager - self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph) + self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict) def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: """ diff --git a/api/core/app/apps/chat/app_config_manager.py b/api/core/app/apps/chat/app_config_manager.py index ac69a92823..553cf34ee9 100644 --- a/api/core/app/apps/chat/app_config_manager.py +++ b/api/core/app/apps/chat/app_config_manager.py @@ -46,7 +46,7 @@ class ChatAppConfigManager(BaseAppConfigManager): else: config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG - if override_config_dict != EasyUIBasedAppModelConfigFrom.ARGS: + if config_from != EasyUIBasedAppModelConfigFrom.ARGS: app_model_config_dict = app_model_config.to_dict() config_dict = app_model_config_dict.copy() else: diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py index 57775f2cce..3d7a4035e7 100644 --- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py +++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py @@ -15,7 +15,7 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): self._queue_manager = queue_manager - self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph) + self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict) def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: """ diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 5423546957..05a784c221 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -5,7 +5,7 @@ from typing import Optional, Union from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback -from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.workflow_entities import WorkflowRunState from core.workflow.nodes.base_node import BaseNode @@ -122,10 +122,10 @@ class WorkflowEngineManager: if 'nodes' not in graph or 'edges' not in graph: raise ValueError('nodes or edges not found in workflow graph') - if isinstance(graph.get('nodes'), list): + if not isinstance(graph.get('nodes'), list): raise ValueError('nodes in workflow 
graph must be a list') - if isinstance(graph.get('edges'), list): + if not isinstance(graph.get('edges'), list): raise ValueError('edges in workflow graph must be a list') # init workflow run @@ -150,6 +150,7 @@ class WorkflowEngineManager: try: predecessor_node = None + has_entry_node = False while True: # get next node, multiple target nodes in the future next_node = self._get_next_node( @@ -161,6 +162,8 @@ class WorkflowEngineManager: if not next_node: break + has_entry_node = True + # max steps 30 reached if len(workflow_run_state.workflow_node_executions) > 30: raise ValueError('Max steps 30 reached.') @@ -182,7 +185,7 @@ class WorkflowEngineManager: predecessor_node = next_node - if not predecessor_node and not next_node: + if not has_entry_node: self._workflow_run_failed( workflow_run_state=workflow_run_state, error='Start node not found in workflow graph.', @@ -219,38 +222,31 @@ class WorkflowEngineManager: :param callbacks: workflow callbacks :return: """ - try: - db.session.begin() + max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ + .filter(WorkflowRun.tenant_id == workflow.tenant_id) \ + .filter(WorkflowRun.app_id == workflow.app_id) \ + .scalar() or 0 + new_sequence_number = max_sequence + 1 - max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ - .filter(WorkflowRun.tenant_id == workflow.tenant_id) \ - .filter(WorkflowRun.app_id == workflow.app_id) \ - .for_update() \ - .scalar() or 0 - new_sequence_number = max_sequence + 1 + # init workflow run + workflow_run = WorkflowRun( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + sequence_number=new_sequence_number, + workflow_id=workflow.id, + type=workflow.type, + triggered_from=triggered_from.value, + version=workflow.version, + graph=workflow.graph, + inputs=json.dumps({**user_inputs, **jsonable_encoder(system_inputs)}), + status=WorkflowRunStatus.RUNNING.value, + created_by_role=(CreatedByRole.ACCOUNT.value + if isinstance(user, Account) else CreatedByRole.END_USER.value), + created_by=user.id + ) - # init workflow run - workflow_run = WorkflowRun( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - sequence_number=new_sequence_number, - workflow_id=workflow.id, - type=workflow.type, - triggered_from=triggered_from.value, - version=workflow.version, - graph=workflow.graph, - inputs=json.dumps({**user_inputs, **system_inputs}), - status=WorkflowRunStatus.RUNNING.value, - created_by_role=(CreatedByRole.ACCOUNT.value - if isinstance(user, Account) else CreatedByRole.END_USER.value), - created_by=user.id - ) - - db.session.add(workflow_run) - db.session.commit() - except: - db.session.rollback() - raise + db.session.add(workflow_run) + db.session.commit() if callbacks: for callback in callbacks: @@ -330,7 +326,7 @@ class WorkflowEngineManager: if not predecessor_node: for node_config in nodes: - if node_config.get('type') == NodeType.START.value: + if node_config.get('data', {}).get('type', '') == NodeType.START.value: return StartNode(config=node_config) else: edges = graph.get('edges') @@ -368,7 +364,7 @@ class WorkflowEngineManager: return None # get next node - target_node = node_classes.get(NodeType.value_of(target_node_config.get('type'))) + target_node = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type'))) return target_node( config=target_node_config, @@ -424,17 +420,18 @@ class WorkflowEngineManager: callbacks=callbacks ) - for variable_key, variable_value in node_run_result.outputs.items(): - # append variables to variable 
pool recursively - self._append_variables_recursively( - variable_pool=workflow_run_state.variable_pool, - node_id=node.node_id, - variable_key_list=[variable_key], - variable_value=variable_value - ) + if node_run_result.outputs: + for variable_key, variable_value in node_run_result.outputs.items(): + # append variables to variable pool recursively + self._append_variables_recursively( + variable_pool=workflow_run_state.variable_pool, + node_id=node.node_id, + variable_key_list=[variable_key], + variable_value=variable_value + ) - if node_run_result.metadata.get('total_tokens'): - workflow_run_state.total_tokens += int(node_run_result.metadata.get('total_tokens')) + if node_run_result.metadata and node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + workflow_run_state.total_tokens += int(node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)) return workflow_node_execution @@ -464,7 +461,6 @@ class WorkflowEngineManager: node_id=node.node_id, node_type=node.node_type.value, title=node.node_data.title, - type=node.node_type.value, status=WorkflowNodeExecutionStatus.RUNNING.value, created_by_role=workflow_run.created_by_role, created_by=workflow_run.created_by @@ -493,10 +489,11 @@ class WorkflowEngineManager: """ workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value workflow_node_execution.elapsed_time = time.perf_counter() - start_at - workflow_node_execution.inputs = json.dumps(result.inputs) - workflow_node_execution.process_data = json.dumps(result.process_data) - workflow_node_execution.outputs = json.dumps(result.outputs) - workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(result.metadata)) + workflow_node_execution.inputs = json.dumps(result.inputs) if result.inputs else None + workflow_node_execution.process_data = json.dumps(result.process_data) if result.process_data else None + workflow_node_execution.outputs = json.dumps(result.outputs) if result.outputs else None + workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(result.metadata)) \ + if result.metadata else None workflow_node_execution.finished_at = datetime.utcnow() db.session.commit() diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py index cf8530dc67..8fadf2dc6c 100644 --- a/api/migrations/versions/b289e2408ee2_add_workflow.py +++ b/api/migrations/versions/b289e2408ee2_add_workflow.py @@ -45,8 +45,8 @@ def upgrade(): sa.Column('node_id', sa.String(length=255), nullable=False), sa.Column('node_type', sa.String(length=255), nullable=False), sa.Column('title', sa.String(length=255), nullable=False), - sa.Column('inputs', sa.Text(), nullable=False), - sa.Column('process_data', sa.Text(), nullable=False), + sa.Column('inputs', sa.Text(), nullable=True), + sa.Column('process_data', sa.Text(), nullable=True), sa.Column('outputs', sa.Text(), nullable=True), sa.Column('status', sa.String(length=255), nullable=False), sa.Column('error', sa.Text(), nullable=True), diff --git a/api/migrations/versions/b5429b71023c_messages_columns_set_nullable.py b/api/migrations/versions/b5429b71023c_messages_columns_set_nullable.py new file mode 100644 index 0000000000..ee81fdab28 --- /dev/null +++ b/api/migrations/versions/b5429b71023c_messages_columns_set_nullable.py @@ -0,0 +1,41 @@ +"""messages columns set nullable + +Revision ID: b5429b71023c +Revises: 42e85ed5564d +Create Date: 2024-03-07 09:52:00.846136 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, 
used by Alembic. +revision = 'b5429b71023c' +down_revision = '42e85ed5564d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=True) + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=True) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=False) + + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index c579c3dee8..6856c4e1b0 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -585,8 +585,8 @@ class Message(db.Model): id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) app_id = db.Column(UUID, nullable=False) - model_provider = db.Column(db.String(255), nullable=False) - model_id = db.Column(db.String(255), nullable=False) + model_provider = db.Column(db.String(255), nullable=True) + model_id = db.Column(db.String(255), nullable=True) override_model_configs = db.Column(db.Text) conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=False) inputs = db.Column(db.JSON) diff --git a/api/models/workflow.py b/api/models/workflow.py index 032134a0d1..0883d0ef13 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -138,7 +138,7 @@ class Workflow(db.Model): if 'nodes' not in graph_dict: return [] - start_node = next((node for node in graph_dict['nodes'] if node['type'] == 'start'), None) + start_node = next((node for node in graph_dict['nodes'] if node['data']['type'] == 'start'), None) if not start_node: return [] @@ -392,8 +392,8 @@ class WorkflowNodeExecution(db.Model): node_id = db.Column(db.String(255), nullable=False) node_type = db.Column(db.String(255), nullable=False) title = db.Column(db.String(255), nullable=False) - inputs = db.Column(db.Text, nullable=False) - process_data = db.Column(db.Text, nullable=False) + inputs = db.Column(db.Text) + process_data = db.Column(db.Text) outputs = db.Column(db.Text) status = db.Column(db.String(255), nullable=False) error = db.Column(db.Text) From 1914dfea7705c7d3d52059b52ab476941e745971 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 7 Mar 2024 20:50:02 +0800 Subject: [PATCH 122/200] fix bugs --- .../advanced_chat/generate_task_pipeline.py | 24 ++++++++++++-- .../nodes/direct_answer/direct_answer_node.py | 2 +- api/core/workflow/workflow_engine_manager.py | 33 ++++++++++++++++++- 3 files changed, 54 insertions(+), 5 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 624a0f430a..c1076fa947 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -47,6 +47,7 @@ class TaskState(BaseModel): answer: str = "" metadata: dict = {} usage: LLMUsage + workflow_run_id: Optional[str] = None class AdvancedChatAppGenerateTaskPipeline: @@ -110,6 +111,8 @@ class AdvancedChatAppGenerateTaskPipeline: } self._task_state.answer = annotation.content + elif isinstance(event, QueueWorkflowStartedEvent): + 
self._task_state.workflow_run_id = event.workflow_run_id elif isinstance(event, QueueNodeFinishedEvent): workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value: @@ -171,6 +174,7 @@ class AdvancedChatAppGenerateTaskPipeline: break elif isinstance(event, QueueWorkflowStartedEvent): workflow_run = self._get_workflow_run(event.workflow_run_id) + self._task_state.workflow_run_id = workflow_run.id response = { 'event': 'workflow_started', 'task_id': self._application_generate_entity.task_id, @@ -234,7 +238,7 @@ class AdvancedChatAppGenerateTaskPipeline: if isinstance(event, QueueWorkflowFinishedEvent): workflow_run = self._get_workflow_run(event.workflow_run_id) if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: - outputs = workflow_run.outputs + outputs = workflow_run.outputs_dict self._task_state.answer = outputs.get('text', '') else: err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) @@ -389,7 +393,13 @@ class AdvancedChatAppGenerateTaskPipeline: :param workflow_run_id: workflow run id :return: """ - return db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() + if workflow_run: + # Because the workflow_run will be modified in the sub-thread, + # and the first query in the main thread will cache the entity, + # you need to expire the entity after the query + db.session.expire(workflow_run) + return workflow_run def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution: """ @@ -397,7 +407,14 @@ class AdvancedChatAppGenerateTaskPipeline: :param workflow_node_execution_id: workflow node execution id :return: """ - return db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).first() + workflow_node_execution = (db.session.query(WorkflowNodeExecution) + .filter(WorkflowNodeExecution.id == workflow_node_execution_id).first()) + if workflow_node_execution: + # Because the workflow_node_execution will be modified in the sub-thread, + # and the first query in the main thread will cache the entity, + # you need to expire the entity after the query + db.session.expire(workflow_node_execution) + return workflow_node_execution def _save_message(self) -> None: """ @@ -408,6 +425,7 @@ class AdvancedChatAppGenerateTaskPipeline: self._message.answer = self._task_state.answer self._message.provider_response_latency = time.perf_counter() - self._start_at + self._message.workflow_run_id = self._task_state.workflow_run_id if self._task_state.metadata and self._task_state.metadata.get('usage'): usage = LLMUsage(**self._task_state.metadata['usage']) diff --git a/api/core/workflow/nodes/direct_answer/direct_answer_node.py b/api/core/workflow/nodes/direct_answer/direct_answer_node.py index 80ecdf7757..bc6e4bd800 100644 --- a/api/core/workflow/nodes/direct_answer/direct_answer_node.py +++ b/api/core/workflow/nodes/direct_answer/direct_answer_node.py @@ -48,7 +48,7 @@ class DirectAnswerNode(BaseNode): return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variable_values, - output={ + outputs={ "answer": answer } ) diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 05a784c221..19dac76631 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ 
b/api/core/workflow/workflow_engine_manager.py @@ -33,6 +33,7 @@ from models.workflow import ( WorkflowRun, WorkflowRunStatus, WorkflowRunTriggeredFrom, + WorkflowType, ) node_classes = { @@ -268,7 +269,7 @@ class WorkflowEngineManager: # fetch last workflow_node_executions last_workflow_node_execution = workflow_run_state.workflow_node_executions[-1] if last_workflow_node_execution: - workflow_run.outputs = json.dumps(last_workflow_node_execution.node_run_result.outputs) + workflow_run.outputs = last_workflow_node_execution.outputs workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at workflow_run.total_tokens = workflow_run_state.total_tokens @@ -390,6 +391,7 @@ class WorkflowEngineManager: workflow_run_state=workflow_run_state, node=node, predecessor_node=predecessor_node, + callbacks=callbacks ) # add to workflow node executions @@ -412,6 +414,9 @@ class WorkflowEngineManager: ) raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}") + # set end node output if in chat + self._set_end_node_output_if_in_chat(workflow_run_state, node, node_run_result) + # node run success self._workflow_node_execution_success( workflow_node_execution=workflow_node_execution, @@ -529,6 +534,32 @@ class WorkflowEngineManager: return workflow_node_execution + def _set_end_node_output_if_in_chat(self, workflow_run_state: WorkflowRunState, + node: BaseNode, + node_run_result: NodeRunResult): + """ + Set end node output if in chat + :param workflow_run_state: workflow run state + :param node: current node + :param node_run_result: node run result + :return: + """ + if workflow_run_state.workflow_run.type == WorkflowType.CHAT.value and node.node_type == NodeType.END: + workflow_node_execution_before_end = workflow_run_state.workflow_node_executions[-2] + if workflow_node_execution_before_end: + if workflow_node_execution_before_end.node_type == NodeType.LLM.value: + if not node_run_result.outputs: + node_run_result.outputs = {} + + node_run_result.outputs['text'] = workflow_node_execution_before_end.outputs_dict.get('text') + elif workflow_node_execution_before_end.node_type == NodeType.DIRECT_ANSWER.value: + if not node_run_result.outputs: + node_run_result.outputs = {} + + node_run_result.outputs['text'] = workflow_node_execution_before_end.outputs_dict.get('answer') + + return node_run_result + def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], From 1a0b6adc2ced6860a477570d0d01b112fc9dd354 Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 8 Mar 2024 16:44:42 +0800 Subject: [PATCH 123/200] fix stream bugs --- api/core/app/apps/advanced_chat/app_generator.py | 2 +- .../app/apps/advanced_chat/generate_task_pipeline.py | 2 +- .../advanced_chat/workflow_event_trigger_callback.py | 2 +- api/core/app/apps/base_app_queue_manager.py | 9 +++++++-- api/core/app/apps/workflow/generate_task_pipeline.py | 2 +- .../app/apps/workflow/workflow_event_trigger_callback.py | 2 +- api/core/app/entities/queue_entities.py | 2 +- 7 files changed, 13 insertions(+), 8 deletions(-) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index a19a5c8f67..92286c9af0 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -54,7 +54,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): inputs = args['inputs'] extras = { - "auto_generate_conversation_name": args['auto_generate_name'] if 
'auto_generate_name' in args else True + "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else False } # get conversation diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index c1076fa947..9c06f516a5 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -346,7 +346,7 @@ class AdvancedChatAppGenerateTaskPipeline: yield self._yield_response(response) elif isinstance(event, QueueTextChunkEvent): - delta_text = event.chunk_text + delta_text = event.text if delta_text is None: continue diff --git a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py index 5d99ce6297..8f72305bb1 100644 --- a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py +++ b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py @@ -76,7 +76,7 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): streamable_node_ids = [] end_node_ids = [] for node_config in graph.get('nodes'): - if node_config.get('type') == NodeType.END.value: + if node_config.get('data', {}).get('type') == NodeType.END.value: end_node_ids.append(node_config.get('id')) for edge_config in graph.get('edges'): diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 0391599040..289567fe5d 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -15,6 +15,7 @@ from core.app.entities.queue_entities import ( QueueMessageEndEvent, QueuePingEvent, QueueStopEvent, + QueueWorkflowFinishedEvent, ) from extensions.ext_redis import redis_client @@ -36,7 +37,8 @@ class AppQueueManager: self._invoke_from = invoke_from user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' - redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}") + redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, + f"{user_prefix}-{self._user_id}") q = queue.Queue() @@ -106,7 +108,10 @@ class AppQueueManager: self._q.put(message) - if isinstance(event, QueueStopEvent | QueueErrorEvent | QueueMessageEndEvent): + if isinstance(event, QueueStopEvent + | QueueErrorEvent + | QueueMessageEndEvent + | QueueWorkflowFinishedEvent): self.stop_listen() if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index df83ad634e..bcd5a4ba3d 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -248,7 +248,7 @@ class WorkflowAppGenerateTaskPipeline: yield self._yield_response(workflow_run_response) elif isinstance(event, QueueTextChunkEvent): - delta_text = event.chunk_text + delta_text = event.text if delta_text is None: continue diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py index 3d7a4035e7..12b93518ed 100644 --- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py +++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py @@ -76,7 +76,7 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): streamable_node_ids = 
[] end_node_ids = [] for node_config in graph.get('nodes'): - if node_config.get('type') == NodeType.END.value: + if node_config.get('data', {}).get('type') == NodeType.END.value: if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text': end_node_ids.append(node_config.get('id')) diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index e5c6a8eff9..38f9638eaa 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -48,7 +48,7 @@ class QueueTextChunkEvent(AppQueueEvent): QueueTextChunkEvent entity """ event = QueueEvent.TEXT_CHUNK - chunk_text: str + text: str class QueueAgentMessageEvent(AppQueueEvent): From c152d55f68f1da84b56ed50e01072b16683eaea6 Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 8 Mar 2024 18:37:08 +0800 Subject: [PATCH 124/200] fix workflow app bugs --- api/controllers/console/app/workflow.py | 8 +-- .../advanced_chat/generate_task_pipeline.py | 55 ++++++++++--------- .../apps/message_based_app_queue_manager.py | 3 +- .../app/apps/workflow/app_queue_manager.py | 3 +- .../apps/workflow/generate_task_pipeline.py | 34 ++++++++++-- api/core/app/entities/queue_entities.py | 17 +++++- api/models/workflow.py | 2 +- 7 files changed, 79 insertions(+), 43 deletions(-) diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 8a68cafad8..30d383ec02 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -129,18 +129,14 @@ class DraftWorkflowRunApi(Resource): args=args, invoke_from=InvokeFrom.DEBUGGER ) + + return compact_response(response) except ValueError as e: raise e except Exception as e: logging.exception("internal server error.") raise InternalServerError() - def generate() -> Generator: - yield from response - - return Response(stream_with_context(generate()), status=200, - mimetype='text/event-stream') - class WorkflowTaskStopApi(Resource): @setup_required diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 9c06f516a5..db22607146 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -235,36 +235,39 @@ class AdvancedChatAppGenerateTaskPipeline: yield self._yield_response(response) elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent): - if isinstance(event, QueueWorkflowFinishedEvent): + if isinstance(event, QueueStopEvent): + workflow_run = self._get_workflow_run(self._task_state.workflow_run_id) + else: workflow_run = self._get_workflow_run(event.workflow_run_id) - if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: - outputs = workflow_run.outputs_dict - self._task_state.answer = outputs.get('text', '') - else: - err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) - data = self._error_to_stream_response_data(self._handle_error(err_event)) - yield self._yield_response(data) - break - workflow_run_response = { - 'event': 'workflow_finished', - 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': event.workflow_run_id, - 'data': { - 'id': workflow_run.id, - 'workflow_id': workflow_run.workflow_id, - 'status': workflow_run.status, - 'outputs': workflow_run.outputs_dict, - 'error': workflow_run.error, - 'elapsed_time': workflow_run.elapsed_time, - 'total_tokens': workflow_run.total_tokens, - 'total_steps': workflow_run.total_steps, - 
'created_at': int(workflow_run.created_at.timestamp()), - 'finished_at': int(workflow_run.finished_at.timestamp()) - } + if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: + outputs = workflow_run.outputs_dict + self._task_state.answer = outputs.get('text', '') + else: + err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) + data = self._error_to_stream_response_data(self._handle_error(err_event)) + yield self._yield_response(data) + break + + workflow_run_response = { + 'event': 'workflow_finished', + 'task_id': self._application_generate_entity.task_id, + 'workflow_run_id': event.workflow_run_id, + 'data': { + 'id': workflow_run.id, + 'workflow_id': workflow_run.workflow_id, + 'status': workflow_run.status, + 'outputs': workflow_run.outputs_dict, + 'error': workflow_run.error, + 'elapsed_time': workflow_run.elapsed_time, + 'total_tokens': workflow_run.total_tokens, + 'total_steps': workflow_run.total_steps, + 'created_at': int(workflow_run.created_at.timestamp()), + 'finished_at': int(workflow_run.finished_at.timestamp()) } + } - yield self._yield_response(workflow_run_response) + yield self._yield_response(workflow_run_response) # response moderation if self._output_moderation_handler: diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py index ed9475502d..13644c99ae 100644 --- a/api/core/app/apps/message_based_app_queue_manager.py +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -2,6 +2,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, + MessageQueueMessage, QueueMessage, ) @@ -20,7 +21,7 @@ class MessageBasedAppQueueManager(AppQueueManager): self._message_id = str(message_id) def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage: - return QueueMessage( + return MessageQueueMessage( task_id=self._task_id, message_id=self._message_id, conversation_id=self._conversation_id, diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py index 0f9b0a1c78..5cf1e58913 100644 --- a/api/core/app/apps/workflow/app_queue_manager.py +++ b/api/core/app/apps/workflow/app_queue_manager.py @@ -3,6 +3,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, QueueMessage, + WorkflowQueueMessage, ) @@ -16,7 +17,7 @@ class WorkflowAppQueueManager(AppQueueManager): self._app_mode = app_mode def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage: - return QueueMessage( + return WorkflowQueueMessage( task_id=self._task_id, app_mode=self._app_mode, event=event diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index bcd5a4ba3d..a48640766a 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -86,7 +86,7 @@ class WorkflowAppGenerateTaskPipeline: workflow_run = self._get_workflow_run(event.workflow_run_id) if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: - outputs = workflow_run.outputs + outputs = workflow_run.outputs_dict self._task_state.answer = outputs.get('text', '') else: raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))) @@ -136,12 +136,11 @@ class WorkflowAppGenerateTaskPipeline: break elif 
isinstance(event, QueueWorkflowStartedEvent): self._task_state.workflow_run_id = event.workflow_run_id - workflow_run = self._get_workflow_run(event.workflow_run_id) response = { 'event': 'workflow_started', 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': event.workflow_run_id, + 'workflow_run_id': workflow_run.id, 'data': { 'id': workflow_run.id, 'workflow_id': workflow_run.workflow_id, @@ -198,7 +197,7 @@ class WorkflowAppGenerateTaskPipeline: workflow_run = self._get_workflow_run(event.workflow_run_id) if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: - outputs = workflow_run.outputs + outputs = workflow_run.outputs_dict self._task_state.answer = outputs.get('text', '') else: err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) @@ -228,6 +227,9 @@ class WorkflowAppGenerateTaskPipeline: yield self._yield_response(replace_response) + # save workflow app log + self._save_workflow_app_log() + workflow_run_response = { 'event': 'workflow_finished', 'task_id': self._application_generate_entity.task_id, @@ -295,7 +297,13 @@ class WorkflowAppGenerateTaskPipeline: :param workflow_run_id: workflow run id :return: """ - return db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() + if workflow_run: + # Because the workflow_run will be modified in the sub-thread, + # and the first query in the main thread will cache the entity, + # you need to expire the entity after the query + db.session.expire(workflow_run) + return workflow_run def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution: """ @@ -303,7 +311,21 @@ class WorkflowAppGenerateTaskPipeline: :param workflow_node_execution_id: workflow node execution id :return: """ - return db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).first() + workflow_node_execution = (db.session.query(WorkflowNodeExecution) + .filter(WorkflowNodeExecution.id == workflow_node_execution_id).first()) + if workflow_node_execution: + # Because the workflow_node_execution will be modified in the sub-thread, + # and the first query in the main thread will cache the entity, + # you need to expire the entity after the query + db.session.expire(workflow_node_execution) + return workflow_node_execution + + def _save_workflow_app_log(self) -> None: + """ + Save workflow app log. 
+        :return:
+        """
+        pass # todo
 
     def _handle_chunk(self, text: str) -> dict:
         """
diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py
index 38f9638eaa..67ed13d721 100644
--- a/api/core/app/entities/queue_entities.py
+++ b/api/core/app/entities/queue_entities.py
@@ -176,7 +176,20 @@ class QueueMessage(BaseModel):
     QueueMessage entity
     """
     task_id: str
-    message_id: str
-    conversation_id: str
     app_mode: str
     event: AppQueueEvent
+
+
+class MessageQueueMessage(QueueMessage):
+    """
+    MessageQueueMessage entity
+    """
+    message_id: str
+    conversation_id: str
+
+
+class WorkflowQueueMessage(QueueMessage):
+    """
+    WorkflowQueueMessage entity
+    """
+    pass
diff --git a/api/models/workflow.py b/api/models/workflow.py
index 0883d0ef13..9768c364dd 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -143,7 +143,7 @@ class Workflow(db.Model):
             return []
 
         # get user_input_form from start node
-        return start_node.get('variables', [])
+        return start_node.get('data', {}).get('variables', [])

From 736e386f15bba02e55b958682c17531eceda5ee6 Mon Sep 17 00:00:00 2001
From: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
Date: Fri, 8 Mar 2024 21:35:58 +0800
Subject: [PATCH 125/200] fix: bugs

---
 api/core/app/apps/agent_chat/app_config_manager.py | 2 +-
 api/core/app/apps/completion/app_config_manager.py | 2 +-
 api/services/completion_service.py                 | 8 ++++----
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py
index 57214f924a..232211c18b 100644
--- a/api/core/app/apps/agent_chat/app_config_manager.py
+++ b/api/core/app/apps/agent_chat/app_config_manager.py
@@ -52,7 +52,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
         else:
             config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
 
-        if override_config_dict != EasyUIBasedAppModelConfigFrom.ARGS:
+        if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
             app_model_config_dict = app_model_config.to_dict()
             config_dict = app_model_config_dict.copy()
         else:
diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py
index a82e68a337..b98a4c16aa 100644
--- a/api/core/app/apps/completion/app_config_manager.py
+++ b/api/core/app/apps/completion/app_config_manager.py
@@ -37,7 +37,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
         else:
             config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
 
-        if override_config_dict != EasyUIBasedAppModelConfigFrom.ARGS:
+        if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
             app_model_config_dict = app_model_config.to_dict()
             config_dict = app_model_config_dict.copy()
         else:
diff --git a/api/services/completion_service.py b/api/services/completion_service.py
index 4e3c4e19f6..eb31ccbb3b 100644
--- a/api/services/completion_service.py
+++ b/api/services/completion_service.py
@@ -30,16 +30,16 @@ class CompletionService:
                 invoke_from=invoke_from,
                 stream=streaming
             )
-        elif app_model.mode == AppMode.CHAT.value:
-            return ChatAppGenerator().generate(
+        elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
+            return AgentChatAppGenerator().generate(
                 app_model=app_model,
                 user=user,
                 args=args,
                 invoke_from=invoke_from,
                 stream=streaming
             )
-        elif app_model.mode == AppMode.AGENT_CHAT.value:
-            return AgentChatAppGenerator().generate(
+        elif app_model.mode == AppMode.CHAT.value:
+            return ChatAppGenerator().generate(
                 app_model=app_model,
                 user=user,
                 args=args,
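
The next patch introduces a remote code-execution sandbox for the new code node. For orientation, here is a minimal client sketch of the contract its CodeExecutor.execute_code relies on: a single authenticated POST that returns stdout/stderr. The endpoint URL, API key, and the 'python3' language tag below are illustrative placeholders, not values from the patch; in the patch itself they come from the new CODE_EXECUTION_ENDPOINT and CODE_EXECUTION_API_KEY settings, and the response fields mirror its CodeExecutionResponse model.

    import httpx

    SANDBOX_ENDPOINT = 'http://localhost:8194'  # placeholder; the patch reads CODE_EXECUTION_ENDPOINT
    SANDBOX_API_KEY = 'dify-sandbox'  # placeholder; the patch reads CODE_EXECUTION_API_KEY

    def run_in_sandbox(language: str, code: str) -> str:
        # POST the code to the sandbox, mirroring the request CodeExecutor.execute_code builds
        response = httpx.post(
            f'{SANDBOX_ENDPOINT}/v1/sandbox/run',
            json={'language': language, 'code': code},
            headers={'X-Api-Key': SANDBOX_API_KEY},
        )
        response.raise_for_status()
        body = response.json()
        # expected shape, per CodeExecutionResponse:
        # {'code': 0, 'message': '', 'data': {'stdout': '...', 'stderr': ''}}
        if body['code'] != 0:
            raise RuntimeError(body['message'])
        if body['data']['stderr']:
            raise RuntimeError(body['data']['stderr'])
        return body['data']['stdout']

    if __name__ == '__main__':
        # 'python3' is an assumed language tag; accepted values depend on the sandbox build
        print(run_in_sandbox('python3', 'print(1 + 1)'))
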
From cb02b1e12e316e6dfd0c995cc71b98b0f995adec Mon Sep 17 00:00:00 2001
From: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
Date: Fri, 8 Mar 2024 23:52:51 +0800
Subject: [PATCH 126/200] feat: code

---
 api/.env.example                              |   4 +
 api/config.py                                 |   7 +-
 api/core/workflow/nodes/code/code_executor.py |  70 +++++++
 api/core/workflow/nodes/code/code_node.py     | 180 +++++++++++++++++-
 api/core/workflow/nodes/code/entities.py      |  19 ++
 .../workflow/nodes/code/python_template.py    |  55 ++++++
 6 files changed, 333 insertions(+), 2 deletions(-)
 create mode 100644 api/core/workflow/nodes/code/code_executor.py
 create mode 100644 api/core/workflow/nodes/code/entities.py
 create mode 100644 api/core/workflow/nodes/code/python_template.py

diff --git a/api/.env.example b/api/.env.example
index 32d89d4287..4a3b1d65af 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -132,3 +132,7 @@ SSRF_PROXY_HTTP_URL=
 SSRF_PROXY_HTTPS_URL=
 
 BATCH_UPLOAD_LIMIT=10
+
+# CODE EXECUTION CONFIGURATION
+CODE_EXECUTION_ENDPOINT=
+CODE_EXECUTION_API_KEY=
diff --git a/api/config.py b/api/config.py
index a978a099b9..a6bc731b82 100644
--- a/api/config.py
+++ b/api/config.py
@@ -59,7 +59,9 @@ DEFAULTS = {
     'CAN_REPLACE_LOGO': 'False',
     'ETL_TYPE': 'dify',
     'KEYWORD_STORE': 'jieba',
-    'BATCH_UPLOAD_LIMIT': 20
+    'BATCH_UPLOAD_LIMIT': 20,
+    'CODE_EXECUTION_ENDPOINT': '',
+    'CODE_EXECUTION_API_KEY': ''
 }
 
 
@@ -293,6 +295,9 @@ class Config:
 
         self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')
 
+        self.CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT')
+        self.CODE_EXECUTION_API_KEY = get_env('CODE_EXECUTION_API_KEY')
+
         self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED')
 
diff --git a/api/core/workflow/nodes/code/code_executor.py b/api/core/workflow/nodes/code/code_executor.py
new file mode 100644
index 0000000000..3ecd7cfd89
--- /dev/null
+++ b/api/core/workflow/nodes/code/code_executor.py
@@ -0,0 +1,70 @@
+from os import environ
+
+from httpx import post
+from yarl import URL
+from pydantic import BaseModel
+
+from core.workflow.nodes.code.python_template import PythonTemplateTransformer
+
+# Code Executor
+CODE_EXECUTION_ENDPOINT = environ.get('CODE_EXECUTION_ENDPOINT', '')
+CODE_EXECUTION_API_KEY = environ.get('CODE_EXECUTION_API_KEY', '')
+
+class CodeExecutionException(Exception):
+    pass
+
+class CodeExecutionResponse(BaseModel):
+    class Data(BaseModel):
+        stdout: str
+        stderr: str
+
+    code: int
+    message: str
+    data: Data
+
+class CodeExecutor:
+    @classmethod
+    def execute_code(cls, language: str, code: str, inputs: dict) -> dict:
+        """
+        Execute code
+        :param language: code language
+        :param code: code
+        :param inputs: inputs
+        :return:
+        """
+        runner = PythonTemplateTransformer.transform_caller(code, inputs)
+
+        url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'run'
+        headers = {
+            'X-Api-Key': CODE_EXECUTION_API_KEY
+        }
+        data = {
+            'language': language,
+            'code': runner,
+        }
+
+        try:
+            response = post(str(url), json=data, headers=headers)
+            if response.status_code == 503:
+                raise CodeExecutionException('Code execution service is unavailable')
+            elif response.status_code != 200:
+                raise Exception('Failed to execute code')
+        except CodeExecutionException as e:
+            raise e
+        except Exception:
+            raise CodeExecutionException('Failed to execute code')
+
+        try:
+            response = response.json()
+        except:
+            raise CodeExecutionException('Failed to parse response')
+
+        response = CodeExecutionResponse(**response)
+
+        if response.code != 0:
+            raise CodeExecutionException(response.message)
+
+        if response.data.stderr:
+            raise CodeExecutionException(response.data.stderr)
+
+        return 
diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py
index 7e69f91d11..dc69fdc84a 100644
--- a/api/core/workflow/nodes/code/code_node.py
+++ b/api/core/workflow/nodes/code/code_node.py
@@ -1,9 +1,23 @@
-from typing import Optional
+from typing import Optional, cast, Union
 
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.code.entities import CodeNodeData
+from core.workflow.nodes.code.code_executor import CodeExecutor, CodeExecutionException
+from models.workflow import WorkflowNodeExecutionStatus
 
+MAX_NUMBER = 2 ** 63 - 1
+MIN_NUMBER = -2 ** 63
+MAX_PRECISION = 20
+MAX_DEPTH = 5
+MAX_STRING_LENGTH = 1000
+MAX_STRING_ARRAY_LENGTH = 30
 
 class CodeNode(BaseNode):
+    _node_data_cls = CodeNodeData
+    node_type = NodeType.CODE
+
     @classmethod
     def get_default_config(cls, filters: Optional[dict] = None) -> dict:
         """
@@ -62,3 +76,167 @@ class CodeNode(BaseNode):
             ]
         }
     }
+
+    def _run(self, variable_pool: Optional[VariablePool] = None,
+             run_args: Optional[dict] = None) -> NodeRunResult:
+        """
+        Run code
+        :param variable_pool: variable pool
+        :param run_args: run args
+        :return:
+        """
+        node_data = self.node_data
+        node_data: CodeNodeData = cast(self._node_data_cls, node_data)
+
+        # SINGLE DEBUG NOT IMPLEMENTED YET
+        if variable_pool is None and run_args:
+            raise ValueError("Single step debug is not supported yet.")
+
+        # Get code language
+        code_language = node_data.code_language
+        code = node_data.code
+
+        # Get variables
+        variables = {}
+        for variable_selector in node_data.variables:
+            variable = variable_selector.variable
+            value = variable_pool.get_variable_value(
+                variable_selector=variable_selector.value_selector
+            )
+
+            variables[variable] = value
+
+        # Run code
+        try:
+            result = CodeExecutor.execute_code(
+                language=code_language,
+                code=code,
+                inputs=variables
+            )
+        except CodeExecutionException as e:
+            return NodeRunResult(
+                status=WorkflowNodeExecutionStatus.FAILED,
+                error=str(e)
+            )
+
+        # Transform result
+        result = self._transform_result(result, node_data.outputs)
+
+        return NodeRunResult(
+            status=WorkflowNodeExecutionStatus.SUCCEEDED,
+            inputs=variables,
+            outputs=result
+        )
+
+    def _check_string(self, value: str, variable: str) -> str:
+        """
+        Check string
+        :param value: value
+        :param variable: variable
+        :return:
+        """
+        if not isinstance(value, str):
+            raise ValueError(f"{variable} in input form must be a string")
+
+        if len(value) > MAX_STRING_LENGTH:
+            raise ValueError(f'{variable} in input form must be less than {MAX_STRING_LENGTH} characters')
+
+        return value.replace('\x00', '')
+
+    def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]:
+        """
+        Check number
+        :param value: value
+        :param variable: variable
+        :return:
+        """
+        if not isinstance(value, (int, float)):
+            raise ValueError(f"{variable} in input form must be a number")
+
+        if value > MAX_NUMBER or value < MIN_NUMBER:
+            raise ValueError(f'{variable} in input form is out of range.')
+
+        if isinstance(value, float):
+            value = round(value, MAX_PRECISION)
+
+        return value
+
+    def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData.Output],
+                          prefix: str = '',
+                          depth: int = 1) -> dict:
+        """
+        Transform result
+        :param result: result
+        :param output_schema: output schema
+        :param prefix: parent output name prefix
+        :param depth: current nesting depth
+        :return:
+        """
+        if depth > MAX_DEPTH:
+            raise ValueError("Depth limit reached, object too deep.")
+
+        transformed_result = {}
+        for output_name, output_config in output_schema.items():
+            if output_config.type == 'object':
+                # check if output is object
+                if not isinstance(result.get(output_name), dict):
+                    raise ValueError(
+                        f'Output {prefix}.{output_name} is not an object, got {type(result.get(output_name))} instead.'
+                    )
+
+                transformed_result[output_name] = self._transform_result(
+                    result=result[output_name],
+                    output_schema=output_config.children,
+                    prefix=f'{prefix}.{output_name}' if prefix else output_name,
+                    depth=depth + 1
+                )
+            elif output_config.type == 'number':
+                # check if number available
+                transformed_result[output_name] = self._check_number(
+                    value=result[output_name],
+                    variable=f'{prefix}.{output_name}' if prefix else output_name
+                )
+            elif output_config.type == 'string':
+                # check if string available
+                transformed_result[output_name] = self._check_string(
+                    value=result[output_name],
+                    variable=f'{prefix}.{output_name}' if prefix else output_name,
+                )
+            elif output_config.type == 'array[number]':
+                # check if array of number available
+                if not isinstance(result[output_name], list):
+                    raise ValueError(
+                        f'Output {prefix}.{output_name} is not an array, got {type(result.get(output_name))} instead.'
+                    )
+
+                transformed_result[output_name] = [
+                    self._check_number(
+                        value=value,
+                        variable=f'{prefix}.{output_name}' if prefix else output_name
+                    )
+                    for value in result[output_name]
+                ]
+            elif output_config.type == 'array[string]':
+                # check if array of string available
+                if not isinstance(result[output_name], list):
+                    raise ValueError(
+                        f'Output {prefix}.{output_name} is not an array, got {type(result.get(output_name))} instead.'
+                    )
+
+                if len(result[output_name]) > MAX_STRING_ARRAY_LENGTH:
+                    raise ValueError(
+                        f'{prefix}.{output_name} in input form must have less than {MAX_STRING_ARRAY_LENGTH} elements'
+                    )
+
+                transformed_result[output_name] = [
+                    self._check_string(
+                        value=value,
+                        variable=f'{prefix}.{output_name}' if prefix else output_name
+                    )
+                    for value in result[output_name]
+                ]
+            else:
+                raise ValueError(f'Output type {output_config.type} is not supported.')
+
+        return transformed_result
\ No newline at end of file
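
For a concrete sense of the checks above: `_transform_result` walks the node's declared output schema recursively (up to `MAX_DEPTH` levels) and applies the string and number limits at each leaf. A hypothetical schema/result pair that would pass, with all names invented for illustration (real schema entries are `CodeNodeData.Output` models, defined in the next file):

```python
# Hypothetical `outputs` declaration for a code node; mirrors the shape of
# CodeNodeData.Output (a type plus optional children for nested objects).
output_schema = {
    'summary': {'type': 'string', 'children': None},
    'score': {'type': 'number', 'children': None},
    'details': {'type': 'object', 'children': {
        'tags': {'type': 'array[string]', 'children': None},
    }},
}

# A sandbox result that satisfies it: strings stay under MAX_STRING_LENGTH,
# the array under MAX_STRING_ARRAY_LENGTH elements, the number inside the
# signed 64-bit range (floats would be rounded to MAX_PRECISION places).
result = {
    'summary': 'All good.',
    'score': 0.97,
    'details': {'tags': ['draft', 'reviewed']},
}
```
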
diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py
new file mode 100644
index 0000000000..731b00f8c8
--- /dev/null
+++ b/api/core/workflow/nodes/code/entities.py
@@ -0,0 +1,19 @@
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.variable_entities import VariableSelector
+
+from pydantic import BaseModel
+from typing import Literal, Union
+
+class CodeNodeData(BaseNodeData):
+    """
+    Code Node Data.
+    """
+    class Output(BaseModel):
+        type: Literal['string', 'number', 'object', 'array[string]', 'array[number]']
+        children: Union[None, dict[str, 'Output']]
+
+    variables: list[VariableSelector]
+    answer: str
+    code_language: str
+    code: str
+    outputs: dict[str, Output]
diff --git a/api/core/workflow/nodes/code/python_template.py b/api/core/workflow/nodes/code/python_template.py
new file mode 100644
index 0000000000..03dfee36f3
--- /dev/null
+++ b/api/core/workflow/nodes/code/python_template.py
@@ -0,0 +1,55 @@
+import json
+import re
+
+PYTHON_RUNNER = """import json
+
+# declare main function here
+{{code}}
+
+# execute main function; inputs is a dict of the node's variables
+output = main(**{{inputs}})
+
+# convert output to json and print it between result markers
+result = '''<<RESULT>>
+{output}
+<<RESULT>>'''.format(output=json.dumps(output))
+
+print(result)
+"""
+
+
+class PythonTemplateTransformer:
+    @classmethod
+    def transform_caller(cls, code: str, inputs: dict) -> str:
+        """
+        Transform code to python runner
+        :param code: code
+        :param inputs: inputs
+        :return:
+        """
+
+        # transform inputs to json string
+        inputs_str = json.dumps(inputs, indent=4)
+
+        # replace code and inputs
+        runner = PYTHON_RUNNER.replace('{{code}}', code)
+        runner = runner.replace('{{inputs}}', inputs_str)
+
+        return runner
+
+    @classmethod
+    def transform_response(cls, response: str) -> dict:
+        """
+        Transform response to dict
+        :param response: response
+        :return:
+        """
+
+        # extract result
+        result = re.search(r'<<RESULT>>(.*)<<RESULT>>', response, re.DOTALL)
+        if not result:
+            raise ValueError('Failed to parse result')
+
+        result = result.group(1)
+        return json.loads(result)

From 5596b3b00b0dbbc3658b70e16bc9b64bd27fa682 Mon Sep 17 00:00:00 2001
From: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
Date: Fri, 8 Mar 2024 23:53:18 +0800
Subject: [PATCH 127/200] fix: linter

---
 api/core/workflow/nodes/code/code_executor.py | 2 +-
 api/core/workflow/nodes/code/code_node.py     | 8 ++++----
 api/core/workflow/nodes/code/entities.py      | 6 ++++--
 3 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/api/core/workflow/nodes/code/code_executor.py b/api/core/workflow/nodes/code/code_executor.py
index 3ecd7cfd89..058ee83d46 100644
--- a/api/core/workflow/nodes/code/code_executor.py
+++ b/api/core/workflow/nodes/code/code_executor.py
@@ -1,8 +1,8 @@
 from os import environ
 
 from httpx import post
-from yarl import URL
 from pydantic import BaseModel
+from yarl import URL
 
 from core.workflow.nodes.code.python_template import PythonTemplateTransformer
 
diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py
index dc69fdc84a..32f6776850 100644
--- a/api/core/workflow/nodes/code/code_node.py
+++ b/api/core/workflow/nodes/code/code_node.py
@@ -1,10 +1,10 @@
-from typing import Optional, cast, Union
+from typing import Optional, Union, cast
+
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
-
 from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.code.code_executor import CodeExecutionException, CodeExecutor
 from core.workflow.nodes.code.entities import CodeNodeData
-from core.workflow.nodes.code.code_executor import CodeExecutor, CodeExecutionException
 from models.workflow import WorkflowNodeExecutionStatus
 
 MAX_NUMBER = 2 ** 63 - 1
@@ -151,7 +151,7 @@ class CodeNode(BaseNode):
         :param variable: variable
         :return:
         """
-        if not isinstance(value, (int, float)):
+        if not isinstance(value, int | float):
             raise ValueError(f"{variable} in input form must be a number")
 
         if value > MAX_NUMBER or value < MIN_NUMBER:
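
The `<<RESULT>>` markers in `python_template.py` form a minimal wire protocol: the caller wraps user code into a script that prints the JSON-encoded return value between two markers, and the response side extracts the span with a regex and `json.loads` it. A self-contained sketch of that round-trip, with a made-up `main` and `exec` standing in for the remote sandbox:

```python
import contextlib
import io
import json
import re

# A toy user function and inputs, standing in for the node's code and variables.
code = """def main(name: str) -> dict:
    return {'greeting': 'Hello, ' + name}
"""
inputs = {'name': 'Dify'}

# Roughly what transform_caller produces: user code plus a harness that prints
# the JSON-encoded result between <<RESULT>> markers.
runner = code + f"""
import json
output = main(**{json.dumps(inputs)})
print('<<RESULT>>' + json.dumps(output) + '<<RESULT>>')
"""

# Stand-in for the sandbox: run the script locally and capture stdout.
stdout = io.StringIO()
with contextlib.redirect_stdout(stdout):
    exec(runner, {'__name__': '__main__'})

# What transform_response does with the captured output.
match = re.search(r'<<RESULT>>(.*)<<RESULT>>', stdout.getvalue(), re.DOTALL)
print(json.loads(match.group(1)))  # -> {'greeting': 'Hello, Dify'}
```
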
diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index 731b00f8c8..2212d77e2d 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -1,8 +1,10 @@ +from typing import Literal, Union + +from pydantic import BaseModel + from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector -from pydantic import BaseModel -from typing import Literal, Union class CodeNodeData(BaseNodeData): """ From fc573564b4f321233b2ddc1b3bf642c2834a7762 Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 8 Mar 2024 23:59:09 +0800 Subject: [PATCH 128/200] refactor workflow runner --- api/controllers/console/app/workflow.py | 7 +- .../app/apps/advanced_chat/app_generator.py | 31 +- api/core/app/apps/advanced_chat/app_runner.py | 33 +- .../advanced_chat/generate_task_pipeline.py | 220 +++++++++--- .../workflow_event_trigger_callback.py | 83 ++++- api/core/app/apps/agent_chat/app_generator.py | 4 +- api/core/app/apps/base_app_queue_manager.py | 27 +- api/core/app/apps/chat/app_generator.py | 4 +- api/core/app/apps/completion/app_generator.py | 4 +- .../app/apps/message_based_app_generator.py | 4 +- .../apps/message_based_app_queue_manager.py | 35 +- api/core/app/apps/workflow/app_generator.py | 14 +- .../app/apps/workflow/app_queue_manager.py | 30 +- api/core/app/apps/workflow/app_runner.py | 33 +- .../apps/workflow/generate_task_pipeline.py | 207 +++++++++--- .../workflow_event_trigger_callback.py | 83 ++++- .../workflow_based_generate_task_pipeline.py | 202 +++++++++++ api/core/app/entities/queue_entities.py | 66 +++- .../callbacks/base_workflow_callback.py | 44 ++- .../workflow/entities/workflow_entities.py | 26 +- .../nodes/direct_answer/direct_answer_node.py | 2 +- api/core/workflow/workflow_engine_manager.py | 319 ++++-------------- api/services/workflow_service.py | 19 +- 23 files changed, 996 insertions(+), 501 deletions(-) create mode 100644 api/core/app/apps/workflow_based_generate_task_pipeline.py diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 30d383ec02..5f03a7cd37 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -147,9 +147,12 @@ class WorkflowTaskStopApi(Resource): """ Stop workflow task """ - # TODO workflow_service = WorkflowService() - workflow_service.stop_workflow_task(app_model=app_model, task_id=task_id, account=current_user) + workflow_service.stop_workflow_task( + task_id=task_id, + user=current_user, + invoke_from=InvokeFrom.DEBUGGER + ) return { "result": "success" diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 92286c9af0..ed45e2ba8a 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -11,7 +11,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline -from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from 
core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom @@ -123,11 +123,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): worker_thread.start() # return response or stream generator - return self._handle_response( + return self._handle_advanced_chat_response( application_generate_entity=application_generate_entity, + workflow=workflow, queue_manager=queue_manager, conversation=conversation, message=message, + user=user, stream=stream ) @@ -159,7 +161,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, message=message ) - except ConversationTaskStoppedException: + except GenerateTaskStoppedException: pass except InvokeAuthorizationError: queue_manager.publish_error( @@ -177,33 +179,40 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): finally: db.session.remove() - def _handle_response(self, application_generate_entity: AdvancedChatAppGenerateEntity, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - stream: bool = False) -> Union[dict, Generator]: + def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool = False) -> Union[dict, Generator]: """ Handle response. :param application_generate_entity: application generate entity + :param workflow: workflow :param queue_manager: queue manager :param conversation: conversation :param message: message + :param user: account or end user :param stream: is stream :return: """ # init generate task pipeline generate_task_pipeline = AdvancedChatAppGenerateTaskPipeline( application_generate_entity=application_generate_entity, + workflow=workflow, queue_manager=queue_manager, conversation=conversation, - message=message + message=message, + user=user, + stream=stream ) try: - return generate_task_pipeline.process(stream=stream) + return generate_task_pipeline.process() except ValueError as e: if e.args[0] == "I/O operation on closed file.": # ignore this error - raise ConversationTaskStoppedException() + raise GenerateTaskStoppedException() else: logger.exception(e) raise e diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 077f0c2de0..3279e00355 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -1,6 +1,6 @@ import logging import time -from typing import cast +from typing import Optional, cast from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback @@ -8,16 +8,14 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, - InvokeFrom, ) from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent from core.moderation.base import ModerationException from core.workflow.entities.node_entities import SystemVariable from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db -from 
models.account import Account -from models.model import App, Conversation, EndUser, Message -from models.workflow import WorkflowRunTriggeredFrom +from models.model import App, Conversation, Message +from models.workflow import Workflow logger = logging.getLogger(__name__) @@ -46,7 +44,7 @@ class AdvancedChatAppRunner(AppRunner): if not app_record: raise ValueError("App not found") - workflow = WorkflowEngineManager().get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) + workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) if not workflow: raise ValueError("Workflow not initialized") @@ -74,19 +72,10 @@ class AdvancedChatAppRunner(AppRunner): ): return - # fetch user - if application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE]: - user = db.session.query(Account).filter(Account.id == application_generate_entity.user_id).first() - else: - user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() - # RUN WORKFLOW workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( workflow=workflow, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING - if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN, - user=user, user_inputs=inputs, system_inputs={ SystemVariable.QUERY: query, @@ -99,6 +88,20 @@ class AdvancedChatAppRunner(AppRunner): )] ) + def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: + """ + Get workflow + """ + # fetch workflow by workflow_id + workflow = db.session.query(Workflow).filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.id == workflow_id + ).first() + + # return workflow + return workflow + def handle_input_moderation(self, queue_manager: AppQueueManager, app_record: App, app_generate_entity: AdvancedChatAppGenerateEntity, diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index db22607146..18bc9c8008 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -4,9 +4,10 @@ import time from collections.abc import Generator from typing import Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, Extra from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.workflow_based_generate_task_pipeline import WorkflowBasedGenerateTaskPipeline from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, InvokeFrom, @@ -16,25 +17,35 @@ from core.app.entities.queue_entities import ( QueueErrorEvent, QueueMessageFileEvent, QueueMessageReplaceEvent, - QueueNodeFinishedEvent, + QueueNodeFailedEvent, QueueNodeStartedEvent, + QueueNodeSucceededEvent, QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent, QueueTextChunkEvent, - QueueWorkflowFinishedEvent, + QueueWorkflowFailedEvent, QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.moderation.output_moderation import ModerationRule, OutputModeration from core.tools.tool_file_manager import ToolFileManager -from core.workflow.entities.node_entities 
import NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable from events.message_event import message_was_created from extensions.ext_database import db -from models.model import Conversation, Message, MessageFile -from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowRun, WorkflowRunStatus +from models.account import Account +from models.model import Conversation, EndUser, Message, MessageFile +from models.workflow import ( + Workflow, + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, + WorkflowRun, + WorkflowRunStatus, + WorkflowRunTriggeredFrom, +) from services.annotation_service import AppAnnotationService logger = logging.getLogger(__name__) @@ -47,41 +58,63 @@ class TaskState(BaseModel): answer: str = "" metadata: dict = {} usage: LLMUsage - workflow_run_id: Optional[str] = None + + workflow_run: Optional[WorkflowRun] = None + start_at: Optional[float] = None + total_tokens: int = 0 + total_steps: int = 0 + + current_node_execution: Optional[WorkflowNodeExecution] = None + current_node_execution_start_at: Optional[float] = None + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True -class AdvancedChatAppGenerateTaskPipeline: +class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): """ AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity, + workflow: Workflow, queue_manager: AppQueueManager, conversation: Conversation, - message: Message) -> None: + message: Message, + user: Union[Account, EndUser], + stream: bool) -> None: """ Initialize GenerateTaskPipeline. :param application_generate_entity: application generate entity + :param workflow: workflow :param queue_manager: queue manager :param conversation: conversation :param message: message + :param user: user + :param stream: stream """ self._application_generate_entity = application_generate_entity + self._workflow = workflow self._queue_manager = queue_manager self._conversation = conversation self._message = message + self._user = user self._task_state = TaskState( usage=LLMUsage.empty_usage() ) self._start_at = time.perf_counter() self._output_moderation_handler = self._init_output_moderation() + self._stream = stream - def process(self, stream: bool) -> Union[dict, Generator]: + def process(self) -> Union[dict, Generator]: """ Process generate task pipeline. 
:return: """ - if stream: + if self._stream: return self._process_stream_response() else: return self._process_blocking_response() @@ -112,22 +145,17 @@ class AdvancedChatAppGenerateTaskPipeline: self._task_state.answer = annotation.content elif isinstance(event, QueueWorkflowStartedEvent): - self._task_state.workflow_run_id = event.workflow_run_id - elif isinstance(event, QueueNodeFinishedEvent): - workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) - if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value: - if workflow_node_execution.node_type == NodeType.LLM.value: - outputs = workflow_node_execution.outputs_dict - usage_dict = outputs.get('usage', {}) - self._task_state.metadata['usage'] = usage_dict - elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent): - if isinstance(event, QueueWorkflowFinishedEvent): - workflow_run = self._get_workflow_run(event.workflow_run_id) - if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: - outputs = workflow_run.outputs - self._task_state.answer = outputs.get('text', '') - else: - raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))) + self._on_workflow_start() + elif isinstance(event, QueueNodeStartedEvent): + self._on_node_start(event) + elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): + self._on_node_finished(event) + elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): + self._on_workflow_finished(event) + workflow_run = self._task_state.workflow_run + + if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value: + raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))) # response moderation if self._output_moderation_handler: @@ -173,8 +201,9 @@ class AdvancedChatAppGenerateTaskPipeline: yield self._yield_response(data) break elif isinstance(event, QueueWorkflowStartedEvent): - workflow_run = self._get_workflow_run(event.workflow_run_id) - self._task_state.workflow_run_id = workflow_run.id + self._on_workflow_start() + workflow_run = self._task_state.workflow_run + response = { 'event': 'workflow_started', 'task_id': self._application_generate_entity.task_id, @@ -188,7 +217,9 @@ class AdvancedChatAppGenerateTaskPipeline: yield self._yield_response(response) elif isinstance(event, QueueNodeStartedEvent): - workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) + self._on_node_start(event) + workflow_node_execution = self._task_state.current_node_execution + response = { 'event': 'node_started', 'task_id': self._application_generate_entity.task_id, @@ -204,8 +235,10 @@ class AdvancedChatAppGenerateTaskPipeline: } yield self._yield_response(response) - elif isinstance(event, QueueNodeFinishedEvent): - workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) + elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): + self._on_node_finished(event) + workflow_node_execution = self._task_state.current_node_execution + if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value: if workflow_node_execution.node_type == NodeType.LLM.value: outputs = workflow_node_execution.outputs_dict @@ -234,16 +267,11 @@ class AdvancedChatAppGenerateTaskPipeline: } yield self._yield_response(response) - elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent): - if isinstance(event, QueueStopEvent): - 
workflow_run = self._get_workflow_run(self._task_state.workflow_run_id) - else: - workflow_run = self._get_workflow_run(event.workflow_run_id) + elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): + self._on_workflow_finished(event) + workflow_run = self._task_state.workflow_run - if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: - outputs = workflow_run.outputs_dict - self._task_state.answer = outputs.get('text', '') - else: + if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value: err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) data = self._error_to_stream_response_data(self._handle_error(err_event)) yield self._yield_response(data) @@ -252,7 +280,7 @@ class AdvancedChatAppGenerateTaskPipeline: workflow_run_response = { 'event': 'workflow_finished', 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': event.workflow_run_id, + 'workflow_run_id': workflow_run.id, 'data': { 'id': workflow_run.id, 'workflow_id': workflow_run.workflow_id, @@ -390,6 +418,102 @@ class AdvancedChatAppGenerateTaskPipeline: else: continue + def _on_workflow_start(self) -> None: + self._task_state.start_at = time.perf_counter() + + workflow_run = self._init_workflow_run( + workflow=self._workflow, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING + if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER + else WorkflowRunTriggeredFrom.APP_RUN, + user=self._user, + user_inputs=self._application_generate_entity.inputs, + system_inputs={ + SystemVariable.QUERY: self._message.query, + SystemVariable.FILES: self._application_generate_entity.files, + SystemVariable.CONVERSATION: self._conversation.id, + } + ) + + self._task_state.workflow_run = workflow_run + + def _on_node_start(self, event: QueueNodeStartedEvent) -> None: + workflow_node_execution = self._init_node_execution_from_workflow_run( + workflow_run=self._task_state.workflow_run, + node_id=event.node_id, + node_type=event.node_type, + node_title=event.node_data.title, + node_run_index=event.node_run_index, + predecessor_node_id=event.predecessor_node_id + ) + + self._task_state.current_node_execution = workflow_node_execution + self._task_state.current_node_execution_start_at = time.perf_counter() + self._task_state.total_steps += 1 + + def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None: + if isinstance(event, QueueNodeSucceededEvent): + workflow_node_execution = self._workflow_node_execution_success( + workflow_node_execution=self._task_state.current_node_execution, + start_at=self._task_state.current_node_execution_start_at, + inputs=event.inputs, + process_data=event.process_data, + outputs=event.outputs, + execution_metadata=event.execution_metadata + ) + + if event.execution_metadata and event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + self._task_state.total_tokens += ( + int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) + + if workflow_node_execution.node_type == NodeType.LLM.value: + outputs = workflow_node_execution.outputs_dict + usage_dict = outputs.get('usage', {}) + self._task_state.metadata['usage'] = usage_dict + else: + workflow_node_execution = self._workflow_node_execution_failed( + workflow_node_execution=self._task_state.current_node_execution, + start_at=self._task_state.current_node_execution_start_at, + error=event.error + ) + + self._task_state.current_node_execution = workflow_node_execution + + def _on_workflow_finished(self, 
event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None: + if isinstance(event, QueueStopEvent): + workflow_run = self._workflow_run_failed( + workflow_run=self._task_state.workflow_run, + start_at=self._task_state.start_at, + total_tokens=self._task_state.total_tokens, + total_steps=self._task_state.total_steps, + status=WorkflowRunStatus.STOPPED, + error='Workflow stopped.' + ) + elif isinstance(event, QueueWorkflowFailedEvent): + workflow_run = self._workflow_run_failed( + workflow_run=self._task_state.workflow_run, + start_at=self._task_state.start_at, + total_tokens=self._task_state.total_tokens, + total_steps=self._task_state.total_steps, + status=WorkflowRunStatus.FAILED, + error=event.error + ) + else: + workflow_run = self._workflow_run_success( + workflow_run=self._task_state.workflow_run, + start_at=self._task_state.start_at, + total_tokens=self._task_state.total_tokens, + total_steps=self._task_state.total_steps, + outputs=self._task_state.current_node_execution.outputs + if self._task_state.current_node_execution else None + ) + + self._task_state.workflow_run = workflow_run + + if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: + outputs = workflow_run.outputs_dict + self._task_state.answer = outputs.get('text', '') + def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun: """ Get workflow run. @@ -397,11 +521,6 @@ class AdvancedChatAppGenerateTaskPipeline: :return: """ workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() - if workflow_run: - # Because the workflow_run will be modified in the sub-thread, - # and the first query in the main thread will cache the entity, - # you need to expire the entity after the query - db.session.expire(workflow_run) return workflow_run def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution: @@ -412,11 +531,6 @@ class AdvancedChatAppGenerateTaskPipeline: """ workflow_node_execution = (db.session.query(WorkflowNodeExecution) .filter(WorkflowNodeExecution.id == workflow_node_execution_id).first()) - if workflow_node_execution: - # Because the workflow_node_execution will be modified in the sub-thread, - # and the first query in the main thread will cache the entity, - # you need to expire the entity after the query - db.session.expire(workflow_node_execution) return workflow_node_execution def _save_message(self) -> None: @@ -428,7 +542,7 @@ class AdvancedChatAppGenerateTaskPipeline: self._message.answer = self._task_state.answer self._message.provider_response_latency = time.perf_counter() - self._start_at - self._message.workflow_run_id = self._task_state.workflow_run_id + self._message.workflow_run_id = self._task_state.workflow_run.id if self._task_state.metadata and self._task_state.metadata.get('usage'): usage = LLMUsage(**self._task_state.metadata['usage']) diff --git a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py index 8f72305bb1..d9c8a2c96d 100644 --- a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py +++ b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py @@ -1,14 +1,19 @@ +from typing import Optional + from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.queue_entities import ( - QueueNodeFinishedEvent, + QueueNodeFailedEvent, QueueNodeStartedEvent, + QueueNodeSucceededEvent, QueueTextChunkEvent, - QueueWorkflowFinishedEvent, + 
QueueWorkflowFailedEvent, QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, ) from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeType -from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun +from models.workflow import Workflow class WorkflowEventTriggerCallback(BaseWorkflowCallback): @@ -17,39 +22,91 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): self._queue_manager = queue_manager self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict) - def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: + def on_workflow_run_started(self) -> None: """ Workflow run started """ self._queue_manager.publish( - QueueWorkflowStartedEvent(workflow_run_id=workflow_run.id), + QueueWorkflowStartedEvent(), PublishFrom.APPLICATION_MANAGER ) - def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None: + def on_workflow_run_succeeded(self) -> None: """ - Workflow run finished + Workflow run succeeded """ self._queue_manager.publish( - QueueWorkflowFinishedEvent(workflow_run_id=workflow_run.id), + QueueWorkflowSucceededEvent(), PublishFrom.APPLICATION_MANAGER ) - def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None: + def on_workflow_run_failed(self, error: str) -> None: + """ + Workflow run failed + """ + self._queue_manager.publish( + QueueWorkflowFailedEvent( + error=error + ), + PublishFrom.APPLICATION_MANAGER + ) + + def on_workflow_node_execute_started(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + node_run_index: int = 1, + predecessor_node_id: Optional[str] = None) -> None: """ Workflow node execute started """ self._queue_manager.publish( - QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution.id), + QueueNodeStartedEvent( + node_id=node_id, + node_type=node_type, + node_data=node_data, + node_run_index=node_run_index, + predecessor_node_id=predecessor_node_id + ), PublishFrom.APPLICATION_MANAGER ) - def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None: + def on_workflow_node_execute_succeeded(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + inputs: Optional[dict] = None, + process_data: Optional[dict] = None, + outputs: Optional[dict] = None, + execution_metadata: Optional[dict] = None) -> None: """ - Workflow node execute finished + Workflow node execute succeeded """ self._queue_manager.publish( - QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution.id), + QueueNodeSucceededEvent( + node_id=node_id, + node_type=node_type, + node_data=node_data, + inputs=inputs, + process_data=process_data, + outputs=outputs, + execution_metadata=execution_metadata + ), + PublishFrom.APPLICATION_MANAGER + ) + + def on_workflow_node_execute_failed(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + error: str) -> None: + """ + Workflow node execute failed + """ + self._queue_manager.publish( + QueueNodeFailedEvent( + node_id=node_id, + node_type=node_type, + node_data=node_data, + error=error + ), PublishFrom.APPLICATION_MANAGER ) diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 6d27620a09..700a340c96 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -11,7 
+11,7 @@ from core.app.app_config.easy_ui_based_app.model_config.converter import ModelCo from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.agent_chat.app_runner import AgentChatAppRunner -from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom @@ -177,7 +177,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, message=message ) - except ConversationTaskStoppedException: + except GenerateTaskStoppedException: pass except InvokeAuthorizationError: queue_manager.publish_error( diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 289567fe5d..43a44819f9 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -11,11 +11,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, QueueErrorEvent, - QueueMessage, - QueueMessageEndEvent, QueuePingEvent, QueueStopEvent, - QueueWorkflowFinishedEvent, ) from extensions.ext_redis import redis_client @@ -103,22 +100,16 @@ class AppQueueManager: :return: """ self._check_for_sqlalchemy_models(event.dict()) - - message = self.construct_queue_message(event) - - self._q.put(message) - - if isinstance(event, QueueStopEvent - | QueueErrorEvent - | QueueMessageEndEvent - | QueueWorkflowFinishedEvent): - self.stop_listen() - - if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): - raise ConversationTaskStoppedException() + self._publish(event, pub_from) @abstractmethod - def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage: + def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: + """ + Publish event to queue + :param event: + :param pub_from: + :return: + """ raise NotImplementedError @classmethod @@ -182,5 +173,5 @@ class AppQueueManager: "that cause thread safety issues is not allowed.") -class ConversationTaskStoppedException(Exception): +class GenerateTaskStoppedException(Exception): pass diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 7ddf8dfe32..317d045c04 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -9,7 +9,7 @@ from pydantic import ValidationError from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.chat.app_runner import ChatAppRunner from core.app.apps.message_based_app_generator import MessageBasedAppGenerator @@ -177,7 +177,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, 
message=message ) - except ConversationTaskStoppedException: + except GenerateTaskStoppedException: pass except InvokeAuthorizationError: queue_manager.publish_error( diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 7150bee3ce..b948938aac 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -9,7 +9,7 @@ from pydantic import ValidationError from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.app.apps.completion.app_runner import CompletionAppRunner from core.app.apps.message_based_app_generator import MessageBasedAppGenerator @@ -166,7 +166,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): queue_manager=queue_manager, message=message ) - except ConversationTaskStoppedException: + except GenerateTaskStoppedException: pass except InvokeAuthorizationError: queue_manager.publish_error( diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 3dee68b5e1..0e76c96ff7 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -7,7 +7,7 @@ from sqlalchemy import and_ from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom from core.app.apps.base_app_generator import BaseAppGenerator -from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException from core.app.apps.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, @@ -60,7 +60,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): return generate_task_pipeline.process(stream=stream) except ValueError as e: if e.args[0] == "I/O operation on closed file.": # ignore this error - raise ConversationTaskStoppedException() + raise GenerateTaskStoppedException() else: logger.exception(e) raise e diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py index 13644c99ae..6d0a71f495 100644 --- a/api/core/app/apps/message_based_app_queue_manager.py +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -1,9 +1,14 @@ -from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, MessageQueueMessage, + QueueErrorEvent, QueueMessage, + QueueMessageEndEvent, + QueueStopEvent, + QueueWorkflowFailedEvent, + QueueWorkflowSucceededEvent, ) @@ -28,3 +33,31 @@ class MessageBasedAppQueueManager(AppQueueManager): app_mode=self._app_mode, event=event ) + + def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: + """ + Publish event to queue + :param event: + :param pub_from: + :return: + """ + message = 
MessageQueueMessage( + task_id=self._task_id, + message_id=self._message_id, + conversation_id=self._conversation_id, + app_mode=self._app_mode, + event=event + ) + + self._q.put(message) + + if isinstance(event, QueueStopEvent + | QueueErrorEvent + | QueueMessageEndEvent + | QueueWorkflowSucceededEvent + | QueueWorkflowFailedEvent): + self.stop_listen() + + if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): + raise GenerateTaskStoppedException() + diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 891ca4c2be..d3303047ca 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -9,7 +9,7 @@ from pydantic import ValidationError from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_generator import BaseAppGenerator -from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner @@ -95,7 +95,9 @@ class WorkflowAppGenerator(BaseAppGenerator): # return response or stream generator return self._handle_response( application_generate_entity=application_generate_entity, + workflow=workflow, queue_manager=queue_manager, + user=user, stream=stream ) @@ -117,7 +119,7 @@ class WorkflowAppGenerator(BaseAppGenerator): application_generate_entity=application_generate_entity, queue_manager=queue_manager ) - except ConversationTaskStoppedException: + except GenerateTaskStoppedException: pass except InvokeAuthorizationError: queue_manager.publish_error( @@ -136,19 +138,25 @@ class WorkflowAppGenerator(BaseAppGenerator): db.session.remove() def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity, + workflow: Workflow, queue_manager: AppQueueManager, + user: Union[Account, EndUser], stream: bool = False) -> Union[dict, Generator]: """ Handle response. 
:param application_generate_entity: application generate entity + :param workflow: workflow :param queue_manager: queue manager + :param user: account or end user :param stream: is stream :return: """ # init generate task pipeline generate_task_pipeline = WorkflowAppGenerateTaskPipeline( application_generate_entity=application_generate_entity, + workflow=workflow, queue_manager=queue_manager, + user=user, stream=stream ) @@ -156,7 +164,7 @@ class WorkflowAppGenerator(BaseAppGenerator): return generate_task_pipeline.process() except ValueError as e: if e.args[0] == "I/O operation on closed file.": # ignore this error - raise ConversationTaskStoppedException() + raise GenerateTaskStoppedException() else: logger.exception(e) raise e diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py index 5cf1e58913..f448138b53 100644 --- a/api/core/app/apps/workflow/app_queue_manager.py +++ b/api/core/app/apps/workflow/app_queue_manager.py @@ -1,8 +1,12 @@ -from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, - QueueMessage, + QueueErrorEvent, + QueueMessageEndEvent, + QueueStopEvent, + QueueWorkflowFailedEvent, + QueueWorkflowSucceededEvent, WorkflowQueueMessage, ) @@ -16,9 +20,27 @@ class WorkflowAppQueueManager(AppQueueManager): self._app_mode = app_mode - def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage: - return WorkflowQueueMessage( + def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: + """ + Publish event to queue + :param event: + :param pub_from: + :return: + """ + message = WorkflowQueueMessage( task_id=self._task_id, app_mode=self._app_mode, event=event ) + + self._q.put(message) + + if isinstance(event, QueueStopEvent + | QueueErrorEvent + | QueueMessageEndEvent + | QueueWorkflowSucceededEvent + | QueueWorkflowFailedEvent): + self.stop_listen() + + if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): + raise GenerateTaskStoppedException() diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 132282ffe3..59a385cb38 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -1,13 +1,12 @@ import logging import time -from typing import cast +from typing import Optional, cast from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.app.entities.app_invoke_entities import ( AppGenerateEntity, - InvokeFrom, WorkflowAppGenerateEntity, ) from core.app.entities.queue_entities import QueueStopEvent, QueueTextChunkEvent @@ -16,9 +15,8 @@ from core.moderation.input_moderation import InputModeration from core.workflow.entities.node_entities import SystemVariable from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db -from models.account import Account -from models.model import App, EndUser -from models.workflow import WorkflowRunTriggeredFrom +from models.model import App +from models.workflow import Workflow logger = logging.getLogger(__name__) @@ -43,7 +41,7 @@ class WorkflowAppRunner: if not app_record: 
raise ValueError("App not found") - workflow = WorkflowEngineManager().get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) + workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) if not workflow: raise ValueError("Workflow not initialized") @@ -59,19 +57,10 @@ class WorkflowAppRunner: ): return - # fetch user - if application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE]: - user = db.session.query(Account).filter(Account.id == application_generate_entity.user_id).first() - else: - user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() - # RUN WORKFLOW workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( workflow=workflow, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING - if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN, - user=user, user_inputs=inputs, system_inputs={ SystemVariable.FILES: files @@ -82,6 +71,20 @@ class WorkflowAppRunner: )] ) + def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: + """ + Get workflow + """ + # fetch workflow by workflow_id + workflow = db.session.query(Workflow).filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.id == workflow_id + ).first() + + # return workflow + return workflow + def handle_input_moderation(self, queue_manager: AppQueueManager, app_record: App, app_generate_entity: WorkflowAppGenerateEntity, diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index a48640766a..721124c4c5 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -4,28 +4,35 @@ import time from collections.abc import Generator from typing import Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, Extra from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.workflow_based_generate_task_pipeline import WorkflowBasedGenerateTaskPipeline from core.app.entities.app_invoke_entities import ( + InvokeFrom, WorkflowAppGenerateEntity, ) from core.app.entities.queue_entities import ( QueueErrorEvent, QueueMessageReplaceEvent, - QueueNodeFinishedEvent, + QueueNodeFailedEvent, QueueNodeStartedEvent, + QueueNodeSucceededEvent, QueuePingEvent, QueueStopEvent, QueueTextChunkEvent, - QueueWorkflowFinishedEvent, + QueueWorkflowFailedEvent, QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.moderation.output_moderation import ModerationRule, OutputModeration +from core.workflow.entities.node_entities import NodeRunMetadataKey, SystemVariable from extensions.ext_database import db -from models.workflow import WorkflowNodeExecution, WorkflowRun, WorkflowRunStatus +from models.account import Account +from models.model import EndUser +from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun, WorkflowRunStatus, WorkflowRunTriggeredFrom logger = logging.getLogger(__name__) @@ -36,24 +43,44 @@ class TaskState(BaseModel): """ answer: str = "" metadata: dict = {} - workflow_run_id: Optional[str] = None + + workflow_run: Optional[WorkflowRun] = None + start_at: Optional[float] = None + 
total_tokens: int = 0 + total_steps: int = 0 + + current_node_execution: Optional[WorkflowNodeExecution] = None + current_node_execution_start_at: Optional[float] = None + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True -class WorkflowAppGenerateTaskPipeline: +class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): """ WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, + workflow: Workflow, queue_manager: AppQueueManager, + user: Union[Account, EndUser], stream: bool) -> None: """ Initialize GenerateTaskPipeline. :param application_generate_entity: application generate entity + :param workflow: workflow :param queue_manager: queue manager + :param user: user + :param stream: is stream """ self._application_generate_entity = application_generate_entity + self._workflow = workflow self._queue_manager = queue_manager + self._user = user self._task_state = TaskState() self._start_at = time.perf_counter() self._output_moderation_handler = self._init_output_moderation() @@ -79,17 +106,15 @@ class WorkflowAppGenerateTaskPipeline: if isinstance(event, QueueErrorEvent): raise self._handle_error(event) - elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent): - if isinstance(event, QueueStopEvent): - workflow_run = self._get_workflow_run(self._task_state.workflow_run_id) - else: - workflow_run = self._get_workflow_run(event.workflow_run_id) - - if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: - outputs = workflow_run.outputs_dict - self._task_state.answer = outputs.get('text', '') - else: - raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))) + elif isinstance(event, QueueWorkflowStartedEvent): + self._on_workflow_start() + elif isinstance(event, QueueNodeStartedEvent): + self._on_node_start(event) + elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): + self._on_node_finished(event) + elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): + self._on_workflow_finished(event) + workflow_run = self._task_state.workflow_run # response moderation if self._output_moderation_handler: @@ -100,10 +125,12 @@ class WorkflowAppGenerateTaskPipeline: public_event=False ) + # save workflow app log + self._save_workflow_app_log() + response = { - 'event': 'workflow_finished', 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': event.workflow_run_id, + 'workflow_run_id': workflow_run.id, 'data': { 'id': workflow_run.id, 'workflow_id': workflow_run.workflow_id, @@ -135,8 +162,9 @@ class WorkflowAppGenerateTaskPipeline: yield self._yield_response(data) break elif isinstance(event, QueueWorkflowStartedEvent): - self._task_state.workflow_run_id = event.workflow_run_id - workflow_run = self._get_workflow_run(event.workflow_run_id) + self._on_workflow_start() + workflow_run = self._task_state.workflow_run + response = { 'event': 'workflow_started', 'task_id': self._application_generate_entity.task_id, @@ -150,7 +178,9 @@ class WorkflowAppGenerateTaskPipeline: yield self._yield_response(response) elif isinstance(event, QueueNodeStartedEvent): - workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) + self._on_node_start(event) + workflow_node_execution = self._task_state.current_node_execution + response = { 
'event': 'node_started', 'task_id': self._application_generate_entity.task_id, @@ -166,8 +196,10 @@ class WorkflowAppGenerateTaskPipeline: } yield self._yield_response(response) - elif isinstance(event, QueueNodeFinishedEvent): - workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) + elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): + self._on_node_finished(event) + workflow_node_execution = self._task_state.current_node_execution + response = { 'event': 'node_finished', 'task_id': self._application_generate_entity.task_id, @@ -190,20 +222,9 @@ class WorkflowAppGenerateTaskPipeline: } yield self._yield_response(response) - elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent): - if isinstance(event, QueueStopEvent): - workflow_run = self._get_workflow_run(self._task_state.workflow_run_id) - else: - workflow_run = self._get_workflow_run(event.workflow_run_id) - - if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: - outputs = workflow_run.outputs_dict - self._task_state.answer = outputs.get('text', '') - else: - err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) - data = self._error_to_stream_response_data(self._handle_error(err_event)) - yield self._yield_response(data) - break + elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): + self._on_workflow_finished(event) + workflow_run = self._task_state.workflow_run # response moderation if self._output_moderation_handler: @@ -219,7 +240,7 @@ class WorkflowAppGenerateTaskPipeline: replace_response = { 'event': 'text_replace', 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': self._task_state.workflow_run_id, + 'workflow_run_id': self._task_state.workflow_run.id, 'data': { 'text': self._task_state.answer } @@ -233,7 +254,7 @@ class WorkflowAppGenerateTaskPipeline: workflow_run_response = { 'event': 'workflow_finished', 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': event.workflow_run_id, + 'workflow_run_id': workflow_run.id, 'data': { 'id': workflow_run.id, 'workflow_id': workflow_run.workflow_id, @@ -244,7 +265,7 @@ class WorkflowAppGenerateTaskPipeline: 'total_tokens': workflow_run.total_tokens, 'total_steps': workflow_run.total_steps, 'created_at': int(workflow_run.created_at.timestamp()), - 'finished_at': int(workflow_run.finished_at.timestamp()) + 'finished_at': int(workflow_run.finished_at.timestamp()) if workflow_run.finished_at else None } } @@ -279,7 +300,7 @@ class WorkflowAppGenerateTaskPipeline: response = { 'event': 'text_replace', 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': self._task_state.workflow_run_id, + 'workflow_run_id': self._task_state.workflow_run.id, 'data': { 'text': event.text } @@ -291,6 +312,95 @@ class WorkflowAppGenerateTaskPipeline: else: continue + def _on_workflow_start(self) -> None: + self._task_state.start_at = time.perf_counter() + + workflow_run = self._init_workflow_run( + workflow=self._workflow, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING + if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER + else WorkflowRunTriggeredFrom.APP_RUN, + user=self._user, + user_inputs=self._application_generate_entity.inputs, + system_inputs={ + SystemVariable.FILES: self._application_generate_entity.files + } + ) + + self._task_state.workflow_run = workflow_run + + def _on_node_start(self, event: QueueNodeStartedEvent) -> None: + workflow_node_execution 
= self._init_node_execution_from_workflow_run( + workflow_run=self._task_state.workflow_run, + node_id=event.node_id, + node_type=event.node_type, + node_title=event.node_data.title, + node_run_index=event.node_run_index, + predecessor_node_id=event.predecessor_node_id + ) + + self._task_state.current_node_execution = workflow_node_execution + self._task_state.current_node_execution_start_at = time.perf_counter() + self._task_state.total_steps += 1 + + def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None: + if isinstance(event, QueueNodeSucceededEvent): + workflow_node_execution = self._workflow_node_execution_success( + workflow_node_execution=self._task_state.current_node_execution, + start_at=self._task_state.current_node_execution_start_at, + inputs=event.inputs, + process_data=event.process_data, + outputs=event.outputs, + execution_metadata=event.execution_metadata + ) + + if event.execution_metadata and event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + self._task_state.total_tokens += ( + int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) + else: + workflow_node_execution = self._workflow_node_execution_failed( + workflow_node_execution=self._task_state.current_node_execution, + start_at=self._task_state.current_node_execution_start_at, + error=event.error + ) + + self._task_state.current_node_execution = workflow_node_execution + + def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None: + if isinstance(event, QueueStopEvent): + workflow_run = self._workflow_run_failed( + workflow_run=self._task_state.workflow_run, + start_at=self._task_state.start_at, + total_tokens=self._task_state.total_tokens, + total_steps=self._task_state.total_steps, + status=WorkflowRunStatus.STOPPED, + error='Workflow stopped.' + ) + elif isinstance(event, QueueWorkflowFailedEvent): + workflow_run = self._workflow_run_failed( + workflow_run=self._task_state.workflow_run, + start_at=self._task_state.start_at, + total_tokens=self._task_state.total_tokens, + total_steps=self._task_state.total_steps, + status=WorkflowRunStatus.FAILED, + error=event.error + ) + else: + workflow_run = self._workflow_run_success( + workflow_run=self._task_state.workflow_run, + start_at=self._task_state.start_at, + total_tokens=self._task_state.total_tokens, + total_steps=self._task_state.total_steps, + outputs=self._task_state.current_node_execution.outputs + if self._task_state.current_node_execution else None + ) + + self._task_state.workflow_run = workflow_run + + if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: + outputs = workflow_run.outputs_dict + self._task_state.answer = outputs.get('text', '') + def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun: """ Get workflow run. 
@@ -298,11 +408,6 @@ class WorkflowAppGenerateTaskPipeline: :return: """ workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() - if workflow_run: - # Because the workflow_run will be modified in the sub-thread, - # and the first query in the main thread will cache the entity, - # you need to expire the entity after the query - db.session.expire(workflow_run) return workflow_run def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution: @@ -313,11 +418,6 @@ class WorkflowAppGenerateTaskPipeline: """ workflow_node_execution = (db.session.query(WorkflowNodeExecution) .filter(WorkflowNodeExecution.id == workflow_node_execution_id).first()) - if workflow_node_execution: - # Because the workflow_node_execution will be modified in the sub-thread, - # and the first query in the main thread will cache the entity, - # you need to expire the entity after the query - db.session.expire(workflow_node_execution) return workflow_node_execution def _save_workflow_app_log(self) -> None: @@ -335,7 +435,7 @@ class WorkflowAppGenerateTaskPipeline: """ response = { 'event': 'text_chunk', - 'workflow_run_id': self._task_state.workflow_run_id, + 'workflow_run_id': self._task_state.workflow_run.id, 'task_id': self._application_generate_entity.task_id, 'data': { 'text': text @@ -398,7 +498,6 @@ class WorkflowAppGenerateTaskPipeline: return { 'event': 'error', 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': self._task_state.workflow_run_id, **data } diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py index 12b93518ed..318466711a 100644 --- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py +++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py @@ -1,14 +1,19 @@ +from typing import Optional + from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.queue_entities import ( - QueueNodeFinishedEvent, + QueueNodeFailedEvent, QueueNodeStartedEvent, + QueueNodeSucceededEvent, QueueTextChunkEvent, - QueueWorkflowFinishedEvent, + QueueWorkflowFailedEvent, QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, ) from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeType -from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun +from models.workflow import Workflow class WorkflowEventTriggerCallback(BaseWorkflowCallback): @@ -17,39 +22,91 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): self._queue_manager = queue_manager self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict) - def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: + def on_workflow_run_started(self) -> None: """ Workflow run started """ self._queue_manager.publish( - QueueWorkflowStartedEvent(workflow_run_id=workflow_run.id), + QueueWorkflowStartedEvent(), PublishFrom.APPLICATION_MANAGER ) - def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None: + def on_workflow_run_succeeded(self) -> None: """ - Workflow run finished + Workflow run succeeded """ self._queue_manager.publish( - QueueWorkflowFinishedEvent(workflow_run_id=workflow_run.id), + QueueWorkflowSucceededEvent(), PublishFrom.APPLICATION_MANAGER ) - def on_workflow_node_execute_started(self, 
workflow_node_execution: WorkflowNodeExecution) -> None: + def on_workflow_run_failed(self, error: str) -> None: + """ + Workflow run failed + """ + self._queue_manager.publish( + QueueWorkflowFailedEvent( + error=error + ), + PublishFrom.APPLICATION_MANAGER + ) + + def on_workflow_node_execute_started(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + node_run_index: int = 1, + predecessor_node_id: Optional[str] = None) -> None: """ Workflow node execute started """ self._queue_manager.publish( - QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution.id), + QueueNodeStartedEvent( + node_id=node_id, + node_type=node_type, + node_data=node_data, + node_run_index=node_run_index, + predecessor_node_id=predecessor_node_id + ), PublishFrom.APPLICATION_MANAGER ) - def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None: + def on_workflow_node_execute_succeeded(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + inputs: Optional[dict] = None, + process_data: Optional[dict] = None, + outputs: Optional[dict] = None, + execution_metadata: Optional[dict] = None) -> None: """ - Workflow node execute finished + Workflow node execute succeeded """ self._queue_manager.publish( - QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution.id), + QueueNodeSucceededEvent( + node_id=node_id, + node_type=node_type, + node_data=node_data, + inputs=inputs, + process_data=process_data, + outputs=outputs, + execution_metadata=execution_metadata + ), + PublishFrom.APPLICATION_MANAGER + ) + + def on_workflow_node_execute_failed(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + error: str) -> None: + """ + Workflow node execute failed + """ + self._queue_manager.publish( + QueueNodeFailedEvent( + node_id=node_id, + node_type=node_type, + node_data=node_data, + error=error + ), PublishFrom.APPLICATION_MANAGER ) diff --git a/api/core/app/apps/workflow_based_generate_task_pipeline.py b/api/core/app/apps/workflow_based_generate_task_pipeline.py new file mode 100644 index 0000000000..3e9a7b9e1f --- /dev/null +++ b/api/core/app/apps/workflow_based_generate_task_pipeline.py @@ -0,0 +1,202 @@ +import json +import time +from datetime import datetime +from typing import Optional, Union + +from core.model_runtime.utils.encoders import jsonable_encoder +from core.workflow.entities.node_entities import NodeType +from extensions.ext_database import db +from models.account import Account +from models.model import EndUser +from models.workflow import ( + CreatedByRole, + Workflow, + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, + WorkflowNodeExecutionTriggeredFrom, + WorkflowRun, + WorkflowRunStatus, + WorkflowRunTriggeredFrom, +) + + +class WorkflowBasedGenerateTaskPipeline: + def _init_workflow_run(self, workflow: Workflow, + triggered_from: WorkflowRunTriggeredFrom, + user: Union[Account, EndUser], + user_inputs: dict, + system_inputs: Optional[dict] = None) -> WorkflowRun: + """ + Init workflow run + :param workflow: Workflow instance + :param triggered_from: triggered from + :param user: account or end user + :param user_inputs: user variables inputs + :param system_inputs: system inputs, like: query, files + :return: + """ + max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ + .filter(WorkflowRun.tenant_id == workflow.tenant_id) \ + .filter(WorkflowRun.app_id == workflow.app_id) \ + .scalar() or 0 + new_sequence_number = max_sequence + 1 + + # init workflow 
run + workflow_run = WorkflowRun( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + sequence_number=new_sequence_number, + workflow_id=workflow.id, + type=workflow.type, + triggered_from=triggered_from.value, + version=workflow.version, + graph=workflow.graph, + inputs=json.dumps({**user_inputs, **jsonable_encoder(system_inputs)}), + status=WorkflowRunStatus.RUNNING.value, + created_by_role=(CreatedByRole.ACCOUNT.value + if isinstance(user, Account) else CreatedByRole.END_USER.value), + created_by=user.id + ) + + db.session.add(workflow_run) + db.session.commit() + + return workflow_run + + def _workflow_run_success(self, workflow_run: WorkflowRun, + start_at: float, + total_tokens: int, + total_steps: int, + outputs: Optional[dict] = None) -> WorkflowRun: + """ + Workflow run success + :param workflow_run: workflow run + :param start_at: start time + :param total_tokens: total tokens + :param total_steps: total steps + :param outputs: outputs + :return: + """ + workflow_run.status = WorkflowRunStatus.SUCCEEDED.value + workflow_run.outputs = outputs + workflow_run.elapsed_time = time.perf_counter() - start_at + workflow_run.total_tokens = total_tokens + workflow_run.total_steps = total_steps + workflow_run.finished_at = datetime.utcnow() + + db.session.commit() + + return workflow_run + + def _workflow_run_failed(self, workflow_run: WorkflowRun, + start_at: float, + total_tokens: int, + total_steps: int, + status: WorkflowRunStatus, + error: str) -> WorkflowRun: + """ + Workflow run failed + :param workflow_run: workflow run + :param start_at: start time + :param total_tokens: total tokens + :param total_steps: total steps + :param status: status + :param error: error message + :return: + """ + workflow_run.status = status.value + workflow_run.error = error + workflow_run.elapsed_time = time.perf_counter() - start_at + workflow_run.total_tokens = total_tokens + workflow_run.total_steps = total_steps + workflow_run.finished_at = datetime.utcnow() + + db.session.commit() + + return workflow_run + + def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun, + node_id: str, + node_type: NodeType, + node_title: str, + node_run_index: int = 1, + predecessor_node_id: Optional[str] = None) -> WorkflowNodeExecution: + """ + Init workflow node execution from workflow run + :param workflow_run: workflow run + :param node_id: node id + :param node_type: node type + :param node_title: node title + :param node_run_index: run index + :param predecessor_node_id: predecessor node id if exists + :return: + """ + # init workflow node execution + workflow_node_execution = WorkflowNodeExecution( + tenant_id=workflow_run.tenant_id, + app_id=workflow_run.app_id, + workflow_id=workflow_run.workflow_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + workflow_run_id=workflow_run.id, + predecessor_node_id=predecessor_node_id, + index=node_run_index, + node_id=node_id, + node_type=node_type.value, + title=node_title, + status=WorkflowNodeExecutionStatus.RUNNING.value, + created_by_role=workflow_run.created_by_role, + created_by=workflow_run.created_by + ) + + db.session.add(workflow_node_execution) + db.session.commit() + + return workflow_node_execution + + def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution, + start_at: float, + inputs: Optional[dict] = None, + process_data: Optional[dict] = None, + outputs: Optional[dict] = None, + execution_metadata: Optional[dict] = None) -> WorkflowNodeExecution: + """ + Workflow node 
execution success + :param workflow_node_execution: workflow node execution + :param start_at: start time + :param inputs: inputs + :param process_data: process data + :param outputs: outputs + :param execution_metadata: execution metadata + :return: + """ + workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + workflow_node_execution.elapsed_time = time.perf_counter() - start_at + workflow_node_execution.inputs = json.dumps(inputs) if inputs else None + workflow_node_execution.process_data = json.dumps(process_data) if process_data else None + workflow_node_execution.outputs = json.dumps(outputs) if outputs else None + workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \ + if execution_metadata else None + workflow_node_execution.finished_at = datetime.utcnow() + + db.session.commit() + + return workflow_node_execution + + def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution, + start_at: float, + error: str) -> WorkflowNodeExecution: + """ + Workflow node execution failed + :param workflow_node_execution: workflow node execution + :param start_at: start time + :param error: error message + :return: + """ + workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value + workflow_node_execution.error = error + workflow_node_execution.elapsed_time = time.perf_counter() - start_at + workflow_node_execution.finished_at = datetime.utcnow() + + db.session.commit() + + return workflow_node_execution diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 67ed13d721..0ea7744b58 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -1,9 +1,11 @@ from enum import Enum -from typing import Any +from typing import Any, Optional from pydantic import BaseModel from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeType class QueueEvent(Enum): @@ -16,9 +18,11 @@ class QueueEvent(Enum): MESSAGE_REPLACE = "message_replace" MESSAGE_END = "message_end" WORKFLOW_STARTED = "workflow_started" - WORKFLOW_FINISHED = "workflow_finished" + WORKFLOW_SUCCEEDED = "workflow_succeeded" + WORKFLOW_FAILED = "workflow_failed" NODE_STARTED = "node_started" - NODE_FINISHED = "node_finished" + NODE_SUCCEEDED = "node_succeeded" + NODE_FAILED = "node_failed" RETRIEVER_RESOURCES = "retriever_resources" ANNOTATION_REPLY = "annotation_reply" AGENT_THOUGHT = "agent_thought" @@ -96,15 +100,21 @@ class QueueWorkflowStartedEvent(AppQueueEvent): QueueWorkflowStartedEvent entity """ event = QueueEvent.WORKFLOW_STARTED - workflow_run_id: str -class QueueWorkflowFinishedEvent(AppQueueEvent): +class QueueWorkflowSucceededEvent(AppQueueEvent): """ - QueueWorkflowFinishedEvent entity + QueueWorkflowSucceededEvent entity """ - event = QueueEvent.WORKFLOW_FINISHED - workflow_run_id: str + event = QueueEvent.WORKFLOW_SUCCEEDED + + +class QueueWorkflowFailedEvent(AppQueueEvent): + """ + QueueWorkflowFailedEvent entity + """ + event = QueueEvent.WORKFLOW_FAILED + error: str class QueueNodeStartedEvent(AppQueueEvent): @@ -112,17 +122,45 @@ class QueueNodeStartedEvent(AppQueueEvent): QueueNodeStartedEvent entity """ event = QueueEvent.NODE_STARTED - workflow_node_execution_id: str + + node_id: str + node_type: NodeType + node_data: BaseNodeData + node_run_index: int = 1 + predecessor_node_id: 
Optional[str] = None -class QueueNodeFinishedEvent(AppQueueEvent): +class QueueNodeSucceededEvent(AppQueueEvent): """ - QueueNodeFinishedEvent entity + QueueNodeSucceededEvent entity """ - event = QueueEvent.NODE_FINISHED - workflow_node_execution_id: str + event = QueueEvent.NODE_SUCCEEDED + + node_id: str + node_type: NodeType + node_data: BaseNodeData + + inputs: Optional[dict] = None + process_data: Optional[dict] = None + outputs: Optional[dict] = None + execution_metadata: Optional[dict] = None + + error: Optional[str] = None + + +class QueueNodeFailedEvent(AppQueueEvent): + """ + QueueNodeFailedEvent entity + """ + event = QueueEvent.NODE_FAILED + + node_id: str + node_type: NodeType + node_data: BaseNodeData + + error: str + - class QueueAgentThoughtEvent(AppQueueEvent): """ QueueAgentThoughtEvent entity diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py index 3866bf2c15..cf2915ed86 100644 --- a/api/core/workflow/callbacks/base_workflow_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -1,34 +1,63 @@ from abc import ABC, abstractmethod +from typing import Optional -from models.workflow import WorkflowNodeExecution, WorkflowRun +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeType class BaseWorkflowCallback(ABC): @abstractmethod - def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: + def on_workflow_run_started(self) -> None: """ Workflow run started """ raise NotImplementedError @abstractmethod - def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None: + def on_workflow_run_succeeded(self) -> None: """ - Workflow run finished + Workflow run succeeded """ raise NotImplementedError @abstractmethod - def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None: + def on_workflow_run_failed(self, error: str) -> None: + """ + Workflow run failed + """ + raise NotImplementedError + + @abstractmethod + def on_workflow_node_execute_started(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + node_run_index: int = 1, + predecessor_node_id: Optional[str] = None) -> None: """ Workflow node execute started """ raise NotImplementedError @abstractmethod - def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None: + def on_workflow_node_execute_succeeded(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + inputs: Optional[dict] = None, + process_data: Optional[dict] = None, + outputs: Optional[dict] = None, + execution_metadata: Optional[dict] = None) -> None: """ - Workflow node execute finished + Workflow node execute succeeded + """ + raise NotImplementedError + + @abstractmethod + def on_workflow_node_execute_failed(self, node_id: str, + node_type: NodeType, + node_data: BaseNodeData, + error: str) -> None: + """ + Workflow node execute failed """ raise NotImplementedError @@ -38,4 +67,3 @@ class BaseWorkflowCallback(ABC): Publish text chunk """ raise NotImplementedError - diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 8c15cb95cd..6c2adfe0fb 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -1,22 +1,32 @@ +from typing import Optional + +from core.workflow.entities.node_entities import NodeRunResult from 
core.workflow.entities.variable_pool import VariablePool -from models.workflow import WorkflowNodeExecution, WorkflowRun +from core.workflow.nodes.base_node import BaseNode +from models.workflow import Workflow + + +class WorkflowNodeAndResult: + node: BaseNode + result: Optional[NodeRunResult] = None + + def __init__(self, node: BaseNode, result: Optional[NodeRunResult] = None): + self.node = node + self.result = result class WorkflowRunState: - workflow_run: WorkflowRun + workflow: Workflow start_at: float user_inputs: dict variable_pool: VariablePool total_tokens: int = 0 - workflow_node_executions: list[WorkflowNodeExecution] = [] + workflow_nodes_and_results: list[WorkflowNodeAndResult] = [] - def __init__(self, workflow_run: WorkflowRun, - start_at: float, - user_inputs: dict, - variable_pool: VariablePool) -> None: - self.workflow_run = workflow_run + def __init__(self, workflow: Workflow, start_at: float, user_inputs: dict, variable_pool: VariablePool): + self.workflow = workflow self.start_at = start_at self.user_inputs = user_inputs self.variable_pool = variable_pool diff --git a/api/core/workflow/nodes/direct_answer/direct_answer_node.py b/api/core/workflow/nodes/direct_answer/direct_answer_node.py index bc6e4bd800..971cbe536e 100644 --- a/api/core/workflow/nodes/direct_answer/direct_answer_node.py +++ b/api/core/workflow/nodes/direct_answer/direct_answer_node.py @@ -43,7 +43,7 @@ class DirectAnswerNode(BaseNode): # publish answer as stream for word in answer: self.publish_text_chunk(word) - time.sleep(0.01) + time.sleep(0.01) # TODO: revisit this hardcoded 0.01s streaming delay return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 19dac76631..628df4ac5f 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,13 +1,11 @@ -import json import time -from datetime import datetime -from typing import Optional, Union +from typing import Optional -from core.model_runtime.utils.encoders import jsonable_encoder +from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool, VariableValue -from core.workflow.entities.workflow_entities import WorkflowRunState +from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode @@ -21,18 +19,9 @@ from core.workflow.nodes.start.start_node import StartNode from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from core.workflow.nodes.tool.tool_node import ToolNode from core.workflow.nodes.variable_assigner.variable_assigner_node import VariableAssignerNode -from extensions.ext_database import db -from models.account import Account -from models.model import App, EndUser from models.workflow import ( - CreatedByRole, Workflow, - WorkflowNodeExecution, WorkflowNodeExecutionStatus, - WorkflowNodeExecutionTriggeredFrom, - WorkflowRun, - WorkflowRunStatus, - WorkflowRunTriggeredFrom, WorkflowType, ) @@ -53,20 +42,6 @@ node_classes = { class WorkflowEngineManager: - def get_workflow(self, app_model: App,
workflow_id: str) -> Optional[Workflow]: - """ - Get workflow - """ - # fetch workflow by workflow_id - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.id == workflow_id - ).first() - - # return workflow - return workflow - def get_default_configs(self) -> list[dict]: """ Get default block configs @@ -100,16 +75,12 @@ class WorkflowEngineManager: return default_config def run_workflow(self, workflow: Workflow, - triggered_from: WorkflowRunTriggeredFrom, - user: Union[Account, EndUser], user_inputs: dict, system_inputs: Optional[dict] = None, callbacks: list[BaseWorkflowCallback] = None) -> None: """ Run workflow :param workflow: Workflow instance - :param triggered_from: triggered from - :param user: account or end user :param user_inputs: user variables inputs :param system_inputs: system inputs, like: query, files :param callbacks: workflow callbacks @@ -130,18 +101,13 @@ class WorkflowEngineManager: raise ValueError('edges in workflow graph must be a list') # init workflow run - workflow_run = self._init_workflow_run( - workflow=workflow, - triggered_from=triggered_from, - user=user, - user_inputs=user_inputs, - system_inputs=system_inputs, - callbacks=callbacks - ) + if callbacks: + for callback in callbacks: + callback.on_workflow_run_started() # init workflow run state workflow_run_state = WorkflowRunState( - workflow_run=workflow_run, + workflow=workflow, start_at=time.perf_counter(), user_inputs=user_inputs, variable_pool=VariablePool( @@ -166,7 +132,7 @@ class WorkflowEngineManager: has_entry_node = True # max steps 30 reached - if len(workflow_run_state.workflow_node_executions) > 30: + if len(workflow_run_state.workflow_nodes_and_results) > 30: raise ValueError('Max steps 30 reached.') # or max execution time 10min reached @@ -188,14 +154,14 @@ class WorkflowEngineManager: if not has_entry_node: self._workflow_run_failed( - workflow_run_state=workflow_run_state, error='Start node not found in workflow graph.', callbacks=callbacks ) return + except GenerateTaskStoppedException as e: + return except Exception as e: self._workflow_run_failed( - workflow_run_state=workflow_run_state, error=str(e), callbacks=callbacks ) @@ -203,112 +169,33 @@ class WorkflowEngineManager: # workflow run success self._workflow_run_success( - workflow_run_state=workflow_run_state, callbacks=callbacks ) - def _init_workflow_run(self, workflow: Workflow, - triggered_from: WorkflowRunTriggeredFrom, - user: Union[Account, EndUser], - user_inputs: dict, - system_inputs: Optional[dict] = None, - callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun: - """ - Init workflow run - :param workflow: Workflow instance - :param triggered_from: triggered from - :param user: account or end user - :param user_inputs: user variables inputs - :param system_inputs: system inputs, like: query, files - :param callbacks: workflow callbacks - :return: - """ - max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ - .filter(WorkflowRun.tenant_id == workflow.tenant_id) \ - .filter(WorkflowRun.app_id == workflow.app_id) \ - .scalar() or 0 - new_sequence_number = max_sequence + 1 - - # init workflow run - workflow_run = WorkflowRun( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - sequence_number=new_sequence_number, - workflow_id=workflow.id, - type=workflow.type, - triggered_from=triggered_from.value, - version=workflow.version, - graph=workflow.graph, - inputs=json.dumps({**user_inputs, 
**jsonable_encoder(system_inputs)}), - status=WorkflowRunStatus.RUNNING.value, - created_by_role=(CreatedByRole.ACCOUNT.value - if isinstance(user, Account) else CreatedByRole.END_USER.value), - created_by=user.id - ) - - db.session.add(workflow_run) - db.session.commit() - - if callbacks: - for callback in callbacks: - callback.on_workflow_run_started(workflow_run) - - return workflow_run - - def _workflow_run_success(self, workflow_run_state: WorkflowRunState, - callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun: + def _workflow_run_success(self, callbacks: list[BaseWorkflowCallback] = None) -> None: """ Workflow run success - :param workflow_run_state: workflow run state :param callbacks: workflow callbacks :return: """ - workflow_run = workflow_run_state.workflow_run - workflow_run.status = WorkflowRunStatus.SUCCEEDED.value - - # fetch last workflow_node_executions - last_workflow_node_execution = workflow_run_state.workflow_node_executions[-1] - if last_workflow_node_execution: - workflow_run.outputs = last_workflow_node_execution.outputs - - workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at - workflow_run.total_tokens = workflow_run_state.total_tokens - workflow_run.total_steps = len(workflow_run_state.workflow_node_executions) - workflow_run.finished_at = datetime.utcnow() - - db.session.commit() if callbacks: for callback in callbacks: - callback.on_workflow_run_finished(workflow_run) + callback.on_workflow_run_succeeded() - return workflow_run - - def _workflow_run_failed(self, workflow_run_state: WorkflowRunState, - error: str, - callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun: + def _workflow_run_failed(self, error: str, + callbacks: list[BaseWorkflowCallback] = None) -> None: """ Workflow run failed - :param workflow_run_state: workflow run state :param error: error message :param callbacks: workflow callbacks :return: """ - workflow_run = workflow_run_state.workflow_run - workflow_run.status = WorkflowRunStatus.FAILED.value - workflow_run.error = error - workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at - workflow_run.total_tokens = workflow_run_state.total_tokens - workflow_run.total_steps = len(workflow_run_state.workflow_node_executions) - workflow_run.finished_at = datetime.utcnow() - - db.session.commit() - if callbacks: for callback in callbacks: - callback.on_workflow_run_finished(workflow_run) - - return workflow_run + callback.on_workflow_run_failed( + error=error + ) def _get_next_node(self, graph: dict, predecessor_node: Optional[BaseNode] = None, @@ -384,18 +271,24 @@ class WorkflowEngineManager: def _run_workflow_node(self, workflow_run_state: WorkflowRunState, node: BaseNode, predecessor_node: Optional[BaseNode] = None, - callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution: - # init workflow node execution - start_at = time.perf_counter() - workflow_node_execution = self._init_node_execution_from_workflow_run( - workflow_run_state=workflow_run_state, + callbacks: list[BaseWorkflowCallback] = None) -> None: + if callbacks: + for callback in callbacks: + callback.on_workflow_node_execute_started( + node_id=node.node_id, + node_type=node.node_type, + node_data=node.node_data, + node_run_index=len(workflow_run_state.workflow_nodes_and_results) + 1, + predecessor_node_id=predecessor_node.node_id if predecessor_node else None + ) + + workflow_nodes_and_result = WorkflowNodeAndResult( node=node, - predecessor_node=predecessor_node, - callbacks=callbacks + result=None ) - # 
add to workflow node executions - workflow_run_state.workflow_node_executions.append(workflow_node_execution) + # add to workflow_nodes_and_results + workflow_run_state.workflow_nodes_and_results.append(workflow_nodes_and_result) # run node, result must have inputs, process_data, outputs, execution_metadata node_run_result = node.run( @@ -406,24 +299,34 @@ class WorkflowEngineManager: if node_run_result.status == WorkflowNodeExecutionStatus.FAILED: # node run failed - self._workflow_node_execution_failed( - workflow_node_execution=workflow_node_execution, - start_at=start_at, - error=node_run_result.error, - callbacks=callbacks - ) + if callbacks: + for callback in callbacks: + callback.on_workflow_node_execute_failed( + node_id=node.node_id, + node_type=node.node_type, + node_data=node.node_data, + error=node_run_result.error + ) + raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}") # set end node output if in chat self._set_end_node_output_if_in_chat(workflow_run_state, node, node_run_result) + workflow_nodes_and_result.result = node_run_result + # node run success - self._workflow_node_execution_success( - workflow_node_execution=workflow_node_execution, - start_at=start_at, - result=node_run_result, - callbacks=callbacks - ) + if callbacks: + for callback in callbacks: + callback.on_workflow_node_execute_succeeded( + node_id=node.node_id, + node_type=node.node_type, + node_data=node.node_data, + inputs=node_run_result.inputs, + process_data=node_run_result.process_data, + outputs=node_run_result.outputs, + execution_metadata=node_run_result.metadata + ) if node_run_result.outputs: for variable_key, variable_value in node_run_result.outputs.items(): @@ -438,105 +341,9 @@ class WorkflowEngineManager: if node_run_result.metadata and node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): workflow_run_state.total_tokens += int(node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)) - return workflow_node_execution - - def _init_node_execution_from_workflow_run(self, workflow_run_state: WorkflowRunState, - node: BaseNode, - predecessor_node: Optional[BaseNode] = None, - callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution: - """ - Init workflow node execution from workflow run - :param workflow_run_state: workflow run state - :param node: current node - :param predecessor_node: predecessor node if exists - :param callbacks: workflow callbacks - :return: - """ - workflow_run = workflow_run_state.workflow_run - - # init workflow node execution - workflow_node_execution = WorkflowNodeExecution( - tenant_id=workflow_run.tenant_id, - app_id=workflow_run.app_id, - workflow_id=workflow_run.workflow_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - workflow_run_id=workflow_run.id, - predecessor_node_id=predecessor_node.node_id if predecessor_node else None, - index=len(workflow_run_state.workflow_node_executions) + 1, - node_id=node.node_id, - node_type=node.node_type.value, - title=node.node_data.title, - status=WorkflowNodeExecutionStatus.RUNNING.value, - created_by_role=workflow_run.created_by_role, - created_by=workflow_run.created_by - ) - - db.session.add(workflow_node_execution) - db.session.commit() - - if callbacks: - for callback in callbacks: - callback.on_workflow_node_execute_started(workflow_node_execution) - - return workflow_node_execution - - def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution, - start_at: float, - result: NodeRunResult, - callbacks: 
list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution: - """ - Workflow node execution success - :param workflow_node_execution: workflow node execution - :param start_at: start time - :param result: node run result - :param callbacks: workflow callbacks - :return: - """ - workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value - workflow_node_execution.elapsed_time = time.perf_counter() - start_at - workflow_node_execution.inputs = json.dumps(result.inputs) if result.inputs else None - workflow_node_execution.process_data = json.dumps(result.process_data) if result.process_data else None - workflow_node_execution.outputs = json.dumps(result.outputs) if result.outputs else None - workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(result.metadata)) \ - if result.metadata else None - workflow_node_execution.finished_at = datetime.utcnow() - - db.session.commit() - - if callbacks: - for callback in callbacks: - callback.on_workflow_node_execute_finished(workflow_node_execution) - - return workflow_node_execution - - def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution, - start_at: float, - error: str, - callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution: - """ - Workflow node execution failed - :param workflow_node_execution: workflow node execution - :param start_at: start time - :param error: error message - :param callbacks: workflow callbacks - :return: - """ - workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value - workflow_node_execution.error = error - workflow_node_execution.elapsed_time = time.perf_counter() - start_at - workflow_node_execution.finished_at = datetime.utcnow() - - db.session.commit() - - if callbacks: - for callback in callbacks: - callback.on_workflow_node_execute_finished(workflow_node_execution) - - return workflow_node_execution - def _set_end_node_output_if_in_chat(self, workflow_run_state: WorkflowRunState, node: BaseNode, - node_run_result: NodeRunResult): + node_run_result: NodeRunResult) -> None: """ Set end node output if in chat :param workflow_run_state: workflow run state @@ -544,21 +351,19 @@ class WorkflowEngineManager: :param node_run_result: node run result :return: """ - if workflow_run_state.workflow_run.type == WorkflowType.CHAT.value and node.node_type == NodeType.END: - workflow_node_execution_before_end = workflow_run_state.workflow_node_executions[-2] - if workflow_node_execution_before_end: - if workflow_node_execution_before_end.node_type == NodeType.LLM.value: + if workflow_run_state.workflow.type == WorkflowType.CHAT.value and node.node_type == NodeType.END: + workflow_nodes_and_result_before_end = workflow_run_state.workflow_nodes_and_results[-2] + if workflow_nodes_and_result_before_end: + if workflow_nodes_and_result_before_end.node.node_type == NodeType.LLM.value: if not node_run_result.outputs: node_run_result.outputs = {} - node_run_result.outputs['text'] = workflow_node_execution_before_end.outputs_dict.get('text') - elif workflow_node_execution_before_end.node_type == NodeType.DIRECT_ANSWER.value: + node_run_result.outputs['text'] = workflow_nodes_and_result_before_end.result.outputs.get('text') + elif workflow_nodes_and_result_before_end.node.node_type == NodeType.DIRECT_ANSWER.value: if not node_run_result.outputs: node_run_result.outputs = {} - node_run_result.outputs['text'] = workflow_node_execution_before_end.outputs_dict.get('answer') - - return node_run_result + node_run_result.outputs['text'] = 
workflow_nodes_and_result_before_end.result.outputs.get('answer') def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 833c22cdff..f8bd80a0b1 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -5,6 +5,7 @@ from typing import Optional, Union from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator +from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom @@ -44,10 +45,14 @@ class WorkflowService: if not app_model.workflow_id: return None - workflow_engine_manager = WorkflowEngineManager() - # fetch published workflow by workflow_id - return workflow_engine_manager.get_workflow(app_model, app_model.workflow_id) + workflow = db.session.query(Workflow).filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.id == app_model.workflow_id + ).first() + + return workflow def sync_draft_workflow(self, app_model: App, graph: dict, @@ -201,6 +206,14 @@ class WorkflowService: return response + def stop_workflow_task(self, task_id: str, + user: Union[Account, EndUser], + invoke_from: InvokeFrom) -> None: + """ + Stop workflow task + """ + AppQueueManager.set_stop_flag(task_id, invoke_from, user.id) + def convert_to_workflow(self, app_model: App, account: Account) -> App: """ Basic mode of chatbot app(expert mode) to workflow From 9b0f83f807d908bc1c7c8ec61fd4c319e8f0f995 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Sat, 9 Mar 2024 00:02:44 +0800 Subject: [PATCH 129/200] fix: add max number array length --- api/core/workflow/nodes/code/code_node.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 32f6776850..e7e8a1c251 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -13,6 +13,7 @@ MAX_PRECISION = 20 MAX_DEPTH = 5 MAX_STRING_LENGTH = 1000 MAX_STRING_ARRAY_LENGTH = 30 +MAX_NUMBER_ARRAY_LENGTH = 1000 class CodeNode(BaseNode): _node_data_cls = CodeNodeData @@ -210,6 +211,11 @@ class CodeNode(BaseNode): f'Output {prefix}.{output_name} is not an array, got {type(result.get(output_name))} instead.' 
) + + if len(result[output_name]) > MAX_NUMBER_ARRAY_LENGTH: + raise ValueError( + f'{prefix}.{output_name} in output form must contain no more than {MAX_NUMBER_ARRAY_LENGTH} elements' + ) + + transformed_result[output_name] = [ self._check_number( value=value, From e90637f67a89042f0327fce5699b07b90768daa1 Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 9 Mar 2024 00:58:12 +0800 Subject: [PATCH 130/200] fix generate bug --- api/core/app/apps/advanced_chat/app_generator.py | 4 ++-- api/core/app/apps/workflow/app_generator.py | 2 -- api/core/workflow/workflow_engine_manager.py | 4 ++-- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index ed45e2ba8a..a0f197ec37 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -216,5 +216,5 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): else: logger.exception(e) raise e - finally: - db.session.remove() + # finally: + # db.session.remove() diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index d3303047ca..b1a70a83ba 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -168,5 +168,3 @@ class WorkflowAppGenerator(BaseAppGenerator): else: logger.exception(e) raise e - finally: - db.session.remove() diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 628df4ac5f..c5af015e87 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -354,12 +354,12 @@ class WorkflowEngineManager: if workflow_run_state.workflow.type == WorkflowType.CHAT.value and node.node_type == NodeType.END: workflow_nodes_and_result_before_end = workflow_run_state.workflow_nodes_and_results[-2] if workflow_nodes_and_result_before_end: - if workflow_nodes_and_result_before_end.node.node_type == NodeType.LLM.value: + if workflow_nodes_and_result_before_end.node.node_type == NodeType.LLM: if not node_run_result.outputs: node_run_result.outputs = {} node_run_result.outputs['text'] = workflow_nodes_and_result_before_end.result.outputs.get('text') - elif workflow_nodes_and_result_before_end.node.node_type == NodeType.DIRECT_ANSWER.value: + elif workflow_nodes_and_result_before_end.node.node_type == NodeType.DIRECT_ANSWER: if not node_run_result.outputs: node_run_result.outputs = {} From 4c5822fb6e2cad159793e85076a94313b1245ec0 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Sat, 9 Mar 2024 15:51:02 +0800 Subject: [PATCH 131/200] fix: transform --- api/core/workflow/nodes/code/code_node.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index e7e8a1c251..77bcccab21 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -196,8 +196,6 @@ class CodeNode(BaseNode): value=result[output_name], variable=f'{prefix}.{output_name}' if prefix else output_name ) - - transformed_result[output_name] = result[output_name] elif output_config.type == 'string': # check if string available transformed_result[output_name] = self._check_string( From 2f57d090a1291087512f5f8ecc11c074fe2f71c5 Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 9 Mar 2024 19:05:48 +0800 Subject: [PATCH 132/200] refactor pipeline and remove node run run_args --- .../advanced_chat/generate_task_pipeline.py |
47 ++++++++---- .../apps/workflow/generate_task_pipeline.py | 48 +++++++++---- api/core/workflow/entities/variable_pool.py | 5 +- .../workflow/entities/workflow_entities.py | 4 +- api/core/workflow/nodes/base_node.py | 34 ++++++--- api/core/workflow/nodes/code/code_node.py | 45 ++++++------ .../nodes/direct_answer/direct_answer_node.py | 21 +++--- api/core/workflow/nodes/end/end_node.py | 71 ++++++++++--------- api/core/workflow/nodes/llm/llm_node.py | 16 ++++- api/core/workflow/nodes/start/start_node.py | 18 +++-- api/core/workflow/workflow_engine_manager.py | 6 +- 11 files changed, 201 insertions(+), 114 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 18bc9c8008..048b429304 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -55,6 +55,19 @@ class TaskState(BaseModel): """ TaskState entity """ + class NodeExecutionInfo(BaseModel): + """ + NodeExecutionInfo entity + """ + workflow_node_execution: WorkflowNodeExecution + start_at: float + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + answer: str = "" metadata: dict = {} usage: LLMUsage @@ -64,8 +77,8 @@ class TaskState(BaseModel): total_tokens: int = 0 total_steps: int = 0 - current_node_execution: Optional[WorkflowNodeExecution] = None - current_node_execution_start_at: Optional[float] = None + running_node_execution_infos: dict[str, NodeExecutionInfo] = {} + latest_node_execution_info: Optional[NodeExecutionInfo] = None class Config: """Configuration for this pydantic object.""" @@ -218,7 +231,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(response) elif isinstance(event, QueueNodeStartedEvent): self._on_node_start(event) - workflow_node_execution = self._task_state.current_node_execution + workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution response = { 'event': 'node_started', @@ -237,7 +250,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(response) elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): self._on_node_finished(event) - workflow_node_execution = self._task_state.current_node_execution + workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value: if workflow_node_execution.node_type == NodeType.LLM.value: @@ -447,15 +460,21 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): predecessor_node_id=event.predecessor_node_id ) - self._task_state.current_node_execution = workflow_node_execution - self._task_state.current_node_execution_start_at = time.perf_counter() + latest_node_execution_info = TaskState.NodeExecutionInfo( + workflow_node_execution=workflow_node_execution, + start_at=time.perf_counter() + ) + + self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info + self._task_state.latest_node_execution_info = latest_node_execution_info self._task_state.total_steps += 1 def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None: + current_node_execution = self._task_state.running_node_execution_infos[event.node_id] if isinstance(event, QueueNodeSucceededEvent): 
workflow_node_execution = self._workflow_node_execution_success( - workflow_node_execution=self._task_state.current_node_execution, - start_at=self._task_state.current_node_execution_start_at, + workflow_node_execution=current_node_execution.workflow_node_execution, + start_at=current_node_execution.start_at, inputs=event.inputs, process_data=event.process_data, outputs=event.outputs, @@ -472,12 +491,14 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): self._task_state.metadata['usage'] = usage_dict else: workflow_node_execution = self._workflow_node_execution_failed( - workflow_node_execution=self._task_state.current_node_execution, - start_at=self._task_state.current_node_execution_start_at, + workflow_node_execution=current_node_execution.workflow_node_execution, + start_at=current_node_execution.start_at, error=event.error ) - self._task_state.current_node_execution = workflow_node_execution + # remove running node execution info + del self._task_state.running_node_execution_infos[event.node_id] + self._task_state.latest_node_execution_info.workflow_node_execution = workflow_node_execution def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None: if isinstance(event, QueueStopEvent): @@ -504,8 +525,8 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, - outputs=self._task_state.current_node_execution.outputs - if self._task_state.current_node_execution else None + outputs=self._task_state.latest_node_execution_info.workflow_node_execution.outputs + if self._task_state.latest_node_execution_info else None ) self._task_state.workflow_run = workflow_run diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 721124c4c5..26e4769fa6 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -41,6 +41,19 @@ class TaskState(BaseModel): """ TaskState entity """ + class NodeExecutionInfo(BaseModel): + """ + NodeExecutionInfo entity + """ + workflow_node_execution: WorkflowNodeExecution + start_at: float + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + answer: str = "" metadata: dict = {} @@ -49,8 +62,8 @@ class TaskState(BaseModel): total_tokens: int = 0 total_steps: int = 0 - current_node_execution: Optional[WorkflowNodeExecution] = None - current_node_execution_start_at: Optional[float] = None + running_node_execution_infos: dict[str, NodeExecutionInfo] = {} + latest_node_execution_info: Optional[NodeExecutionInfo] = None class Config: """Configuration for this pydantic object.""" @@ -179,7 +192,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(response) elif isinstance(event, QueueNodeStartedEvent): self._on_node_start(event) - workflow_node_execution = self._task_state.current_node_execution + workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution response = { 'event': 'node_started', @@ -198,7 +211,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(response) elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): self._on_node_finished(event) - workflow_node_execution = 
self._task_state.current_node_execution + workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution response = { 'event': 'node_finished', @@ -339,15 +352,22 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): predecessor_node_id=event.predecessor_node_id ) - self._task_state.current_node_execution = workflow_node_execution - self._task_state.current_node_execution_start_at = time.perf_counter() + latest_node_execution_info = TaskState.NodeExecutionInfo( + workflow_node_execution=workflow_node_execution, + start_at=time.perf_counter() + ) + + self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info + self._task_state.latest_node_execution_info = latest_node_execution_info + self._task_state.total_steps += 1 def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None: + current_node_execution = self._task_state.running_node_execution_infos[event.node_id] if isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._workflow_node_execution_success( - workflow_node_execution=self._task_state.current_node_execution, - start_at=self._task_state.current_node_execution_start_at, + workflow_node_execution=current_node_execution.workflow_node_execution, + start_at=current_node_execution.start_at, inputs=event.inputs, process_data=event.process_data, outputs=event.outputs, @@ -359,12 +379,14 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) else: workflow_node_execution = self._workflow_node_execution_failed( - workflow_node_execution=self._task_state.current_node_execution, - start_at=self._task_state.current_node_execution_start_at, + workflow_node_execution=current_node_execution.workflow_node_execution, + start_at=current_node_execution.start_at, error=event.error ) - self._task_state.current_node_execution = workflow_node_execution + # remove running node execution info + del self._task_state.running_node_execution_infos[event.node_id] + self._task_state.latest_node_execution_info.workflow_node_execution = workflow_node_execution def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None: if isinstance(event, QueueStopEvent): @@ -391,8 +413,8 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, - outputs=self._task_state.current_node_execution.outputs - if self._task_state.current_node_execution else None + outputs=self._task_state.latest_node_execution_info.workflow_node_execution.outputs + if self._task_state.latest_node_execution_info else None ) self._task_state.workflow_run = workflow_run diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index e84044dede..3868041a8f 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -19,14 +19,17 @@ class ValueType(Enum): class VariablePool: variables_mapping = {} + user_inputs: dict - def __init__(self, system_variables: dict[SystemVariable, Any]) -> None: + def __init__(self, system_variables: dict[SystemVariable, Any], + user_inputs: dict) -> None: # system variables # for example: # { # 'query': 'abc', # 'files': [] # } + self.user_inputs = user_inputs for system_variable, value in system_variables.items(): 
self.append_variable('sys', [system_variable.value], value) diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 6c2adfe0fb..768ad6a130 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -18,15 +18,13 @@ class WorkflowNodeAndResult: class WorkflowRunState: workflow: Workflow start_at: float - user_inputs: dict variable_pool: VariablePool total_tokens: int = 0 workflow_nodes_and_results: list[WorkflowNodeAndResult] = [] - def __init__(self, workflow: Workflow, start_at: float, user_inputs: dict, variable_pool: VariablePool): + def __init__(self, workflow: Workflow, start_at: float, variable_pool: VariablePool): self.workflow = workflow self.start_at = start_at - self.user_inputs = user_inputs self.variable_pool = variable_pool diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 6720017d9f..3f2e806433 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -28,31 +28,23 @@ class BaseNode(ABC): self.callbacks = callbacks or [] @abstractmethod - def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> NodeRunResult: + def _run(self, variable_pool: VariablePool) -> NodeRunResult: """ Run node :param variable_pool: variable pool - :param run_args: run args :return: """ raise NotImplementedError - def run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> NodeRunResult: + def run(self, variable_pool: VariablePool) -> NodeRunResult: """ Run node entry :param variable_pool: variable pool - :param run_args: run args :return: """ - if variable_pool is None and run_args is None: - raise ValueError("At least one of `variable_pool` or `run_args` must be provided.") - try: result = self._run( - variable_pool=variable_pool, - run_args=run_args + variable_pool=variable_pool ) except Exception as e: # process unhandled exception @@ -77,6 +69,26 @@ class BaseNode(ABC): text=text ) + @classmethod + def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict: + """ + Extract variable selector to variable mapping + :param config: node config + :return: + """ + node_data = cls._node_data_cls(**config.get("data", {})) + return cls._extract_variable_selector_to_variable_mapping(node_data) + + @classmethod + @abstractmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + raise NotImplementedError + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 77bcccab21..a65edafbad 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,5 +1,6 @@ from typing import Optional, Union, cast +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode @@ -15,6 +16,7 @@ MAX_STRING_LENGTH = 1000 MAX_STRING_ARRAY_LENGTH = 30 MAX_NUMBER_ARRAY_LENGTH = 1000 + class CodeNode(BaseNode): _node_data_cls = CodeNodeData node_type = NodeType.CODE @@ -78,21 +80,15 @@ class CodeNode(BaseNode): } 
} - def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> NodeRunResult: + def _run(self, variable_pool: VariablePool) -> NodeRunResult: """ Run code :param variable_pool: variable pool - :param run_args: run args :return: """ node_data = self.node_data - node_data: CodeNodeData = cast(self._node_data_cls, node_data) + node_data = cast(self._node_data_cls, node_data) - # SINGLE DEBUG NOT IMPLEMENTED YET - if variable_pool is None and run_args: - raise ValueError("Not support single step debug.") - # Get code language code_language = node_data.code_language code = node_data.code @@ -134,7 +130,6 @@ class CodeNode(BaseNode): Check string :param value: value :param variable: variable - :param max_length: max length :return: """ if not isinstance(value, str): @@ -142,9 +137,9 @@ class CodeNode(BaseNode): if len(value) > MAX_STRING_LENGTH: raise ValueError(f'{variable} in input form must be less than {MAX_STRING_LENGTH} characters') - + return value.replace('\x00', '') - + def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]: """ Check number @@ -157,13 +152,13 @@ class CodeNode(BaseNode): if value > MAX_NUMBER or value < MIN_NUMBER: raise ValueError(f'{variable} in input form is out of range.') - + if isinstance(value, float): value = round(value, MAX_PRECISION) return value - def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData.Output], + def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData.Output], prefix: str = '', depth: int = 1) -> dict: """ @@ -174,7 +169,7 @@ class CodeNode(BaseNode): """ if depth > MAX_DEPTH: raise ValueError("Depth limit reached, object too deep.") - + transformed_result = {} for output_name, output_config in output_schema.items(): if output_config.type == 'object': @@ -183,7 +178,7 @@ class CodeNode(BaseNode): raise ValueError( f'Output {prefix}.{output_name} is not an object, got {type(result.get(output_name))} instead.' ) - + transformed_result[output_name] = self._transform_result( result=result[output_name], output_schema=output_config.children, @@ -208,7 +203,7 @@ class CodeNode(BaseNode): raise ValueError( f'Output {prefix}.{output_name} is not an array, got {type(result.get(output_name))} instead.' ) - + if len(result[output_name]) > MAX_NUMBER_ARRAY_LENGTH: raise ValueError( f'{prefix}.{output_name} in input form must be less than {MAX_NUMBER_ARRAY_LENGTH} characters' ) @@ -227,12 +222,12 @@ class CodeNode(BaseNode): raise ValueError( f'Output {prefix}.{output_name} is not an array, got {type(result.get(output_name))} instead.' ) - + if len(result[output_name]) > MAX_STRING_ARRAY_LENGTH: raise ValueError( f'{prefix}.{output_name} in input form must be less than {MAX_STRING_ARRAY_LENGTH} characters' ) - + transformed_result[output_name] = [ self._check_string( value=value, @@ -242,5 +237,15 @@ class CodeNode(BaseNode): ] else: raise ValueError(f'Output type {output_config.type} is not supported.') - - return transformed_result \ No newline at end of file + + return transformed_result + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + # TODO extract variable selector to variable mapping for single step debugging + return {}
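The `_extract_variable_selector_to_variable_mapping` hook added above is still a stub for the code node at this point in the series (patch 133 fills it in). To make the intended contract concrete, here is a rough sketch of the mapping a single-step debugger would consume; the selector values are illustrative, not taken from the patch:

```python
# A code node with two inputs wired from an upstream node would declare
# selectors like these (illustrative identifiers):
#   variable='arg1', value_selector=['start', 'weather']
#   variable='arg2', value_selector=['start', 'temperature']
#
# The hook is expected to pair each wired selector with the local variable
# name the node's code sees, so a debugger knows which inputs to collect:
expected_mapping = {
    ('start', 'weather'): 'arg1',
    ('start', 'temperature'): 'arg2',
}
# Note: the signatures in this series annotate the return type as
# dict[list[str], str], but list keys are unhashable in Python, so this
# sketch uses tuples; an inverted variable -> selector mapping would also work.
```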
diff --git a/api/core/workflow/nodes/direct_answer/direct_answer_node.py b/api/core/workflow/nodes/direct_answer/direct_answer_node.py index 971cbe536e..9193bab9ee 100644 --- a/api/core/workflow/nodes/direct_answer/direct_answer_node.py +++ b/api/core/workflow/nodes/direct_answer/direct_answer_node.py @@ -1,7 +1,8 @@ import time -from typing import Optional, cast +from typing import cast from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import ValueType, VariablePool from core.workflow.nodes.base_node import BaseNode @@ -13,20 +14,15 @@ class DirectAnswerNode(BaseNode): _node_data_cls = DirectAnswerNodeData node_type = NodeType.DIRECT_ANSWER - def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> NodeRunResult: + def _run(self, variable_pool: VariablePool) -> NodeRunResult: """ Run node :param variable_pool: variable pool - :param run_args: run args :return: """ node_data = self.node_data node_data = cast(self._node_data_cls, node_data) - if variable_pool is None and run_args: - raise ValueError("Not support single step debug.") - variable_values = {} for variable_selector in node_data.variables: value = variable_pool.get_variable_value( @@ -43,7 +39,7 @@ class DirectAnswerNode(BaseNode): # publish answer as stream for word in answer: self.publish_text_chunk(word) - time.sleep(0.01) # todo sleep 0.01 + time.sleep(0.01) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -52,3 +48,12 @@ class DirectAnswerNode(BaseNode): "answer": answer } ) + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + return {} diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 62429e3ac2..65b0b86aa0 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,5 +1,6 @@ -from typing import Optional, cast +from typing import cast +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import ValueType, VariablePool from core.workflow.nodes.base_node import BaseNode @@ -11,50 +12,54 @@ class EndNode(BaseNode): _node_data_cls = EndNodeData node_type = NodeType.END - def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> NodeRunResult: + def 
_run(self, variable_pool: VariablePool) -> NodeRunResult: """ Run node :param variable_pool: variable pool - :param run_args: run args :return: """ node_data = self.node_data node_data = cast(self._node_data_cls, node_data) outputs_config = node_data.outputs - if variable_pool is not None: - outputs = None - if outputs_config: - if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT: - plain_text_selector = outputs_config.plain_text_selector - if plain_text_selector: - outputs = { - 'text': variable_pool.get_variable_value( - variable_selector=plain_text_selector, - target_value_type=ValueType.STRING - ) - } - else: - outputs = { - 'text': '' - } - elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED: - structured_variables = outputs_config.structured_variables - if structured_variables: - outputs = {} - for variable_selector in structured_variables: - variable_value = variable_pool.get_variable_value( - variable_selector=variable_selector.value_selector - ) - outputs[variable_selector.variable] = variable_value - else: - outputs = {} - else: - raise ValueError("Not support single step debug.") + outputs = None + if outputs_config: + if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT: + plain_text_selector = outputs_config.plain_text_selector + if plain_text_selector: + outputs = { + 'text': variable_pool.get_variable_value( + variable_selector=plain_text_selector, + target_value_type=ValueType.STRING + ) + } + else: + outputs = { + 'text': '' + } + elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED: + structured_variables = outputs_config.structured_variables + if structured_variables: + outputs = {} + for variable_selector in structured_variables: + variable_value = variable_pool.get_variable_value( + variable_selector=variable_selector.value_selector + ) + outputs[variable_selector.variable] = variable_value + else: + outputs = {} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=outputs, outputs=outputs ) + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + return {} diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index e3ae9fc00f..90a7755b85 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -1,5 +1,6 @@ from typing import Optional, cast +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode @@ -10,12 +11,10 @@ class LLMNode(BaseNode): _node_data_cls = LLMNodeData node_type = NodeType.LLM - def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> NodeRunResult: + def _run(self, variable_pool: VariablePool) -> NodeRunResult: """ Run node :param variable_pool: variable pool - :param run_args: run args :return: """ node_data = self.node_data @@ -23,6 +22,17 @@ class LLMNode(BaseNode): pass + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + # TODO extract variable selector to variable mapping for single step debugging + 
return {} + + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index ce04031b04..2321e04bd4 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,6 +1,7 @@ -from typing import Optional, cast +from typing import cast from core.app.app_config.entities import VariableEntity +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode @@ -12,12 +13,10 @@ class StartNode(BaseNode): _node_data_cls = StartNodeData node_type = NodeType.START - def _run(self, variable_pool: Optional[VariablePool] = None, - run_args: Optional[dict] = None) -> NodeRunResult: + def _run(self, variable_pool: VariablePool) -> NodeRunResult: """ Run node :param variable_pool: variable pool - :param run_args: run args :return: """ node_data = self.node_data @@ -25,7 +24,7 @@ class StartNode(BaseNode): variables = node_data.variables # Get cleaned inputs - cleaned_inputs = self._get_cleaned_inputs(variables, run_args) + cleaned_inputs = self._get_cleaned_inputs(variables, variable_pool.user_inputs) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -68,3 +67,12 @@ class StartNode(BaseNode): filtered_inputs[variable] = value.replace('\x00', '') if value else None return filtered_inputs + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + return {} diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index c5af015e87..0b96717de7 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -109,9 +109,9 @@ class WorkflowEngineManager: workflow_run_state = WorkflowRunState( workflow=workflow, start_at=time.perf_counter(), - user_inputs=user_inputs, variable_pool=VariablePool( system_variables=system_inputs, + user_inputs=user_inputs ) ) @@ -292,9 +292,7 @@ class WorkflowEngineManager: # run node, result must have inputs, process_data, outputs, execution_metadata node_run_result = node.run( - variable_pool=workflow_run_state.variable_pool, - run_args=workflow_run_state.user_inputs - if (not predecessor_node and node.node_type == NodeType.START) else None # only on start node + variable_pool=workflow_run_state.variable_pool ) if node_run_result.status == WorkflowNodeExecutionStatus.FAILED: From a0fd731170a1fc6a890d7a9618b6c25e164b72c4 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Sat, 9 Mar 2024 19:45:57 +0800 Subject: [PATCH 133/200] feat: mapping variables --- api/core/workflow/nodes/code/code_node.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index a65edafbad..170f2b9cd8 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -87,7 +87,7 @@ class CodeNode(BaseNode): :return: """ node_data = self.node_data - node_data = cast(self._node_data_cls, node_data) + node_data: CodeNodeData = cast(self._node_data_cls, node_data) # Get code language code_language = node_data.code_language 
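The hunk below replaces the code node's TODO stub with a real mapping. For context, the class-level entry point added to `BaseNode` earlier in the series is what callers actually use; a minimal sketch, assuming the `{'data': {...}}` config shape that `extract_variable_selector_to_variable_mapping` parses (the payload values, including `title`, are illustrative):

```python
# Hypothetical node config; extract_variable_selector_to_variable_mapping
# builds _node_data_cls from config['data'] and delegates to the per-node
# hook implemented in the hunk that follows.
node_config = {
    'data': {
        'title': 'code',
        'code_language': 'python3',
        'code': 'def main(arg1):\n    return {"result": arg1}',
        'answer': '',
        'outputs': {},
        'variables': [
            {'variable': 'arg1', 'value_selector': ['start', 'query']},
        ],
    },
}

mapping = CodeNode.extract_variable_selector_to_variable_mapping(node_config)
# one entry per wired input: the selector ['start', 'query'] paired with 'arg1'
```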
@@ -241,11 +241,13 @@ class CodeNode(BaseNode): return transformed_result @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: CodeNodeData) -> dict[list[str], str]: """ Extract variable selector to variable mapping :param node_data: node data :return: """ - # TODO extract variable selector to variable mapping for single step debugging - return {} + + return { + variable_selector.value_selector: variable_selector.variable for variable_selector in node_data.variables + } \ No newline at end of file From 193bcce236176abc939693f80f3584d2fb1f36eb Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Sat, 9 Mar 2024 19:59:47 +0800 Subject: [PATCH 134/200] feat: http request --- api/core/workflow/nodes/code/code_node.py | 1 - .../workflow/nodes/http_request/entities.py | 31 +++++++++++++++++++ .../nodes/http_request/http_request_node.py | 20 ++++++++++-- 3 files changed, 49 insertions(+), 3 deletions(-) create mode 100644 api/core/workflow/nodes/http_request/entities.py diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 170f2b9cd8..3d3c475d06 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,6 +1,5 @@ from typing import Optional, Union, cast -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py new file mode 100644 index 0000000000..8610e88e55 --- /dev/null +++ b/api/core/workflow/nodes/http_request/entities.py @@ -0,0 +1,31 @@ +from typing import Literal, Union + +from pydantic import BaseModel + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_entities import VariableSelector + + +class HttpRequestNodeData(BaseNodeData): + """ + HTTP Request Node Data.
+ """ + class Authorization(BaseModel): + class Config(BaseModel): + type: Literal[None, 'basic', 'bearer', 'custom'] + api_key: Union[None, str] + header: Union[None, str] + + type: Literal['no-auth', 'api-key'] + + class Body(BaseModel): + type: Literal[None, 'form-data', 'x-www-form-urlencoded', 'raw'] + data: Union[None, str] + + variables: list[VariableSelector] + method: Literal['get', 'post', 'put', 'patch', 'delete'] + url: str + authorization: Authorization + headers: str + params: str + \ No newline at end of file diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index 5be25a9834..d0fa29646f 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -1,5 +1,21 @@ +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode - +from core.workflow.nodes.http_request.entities import HttpRequestNodeData class HttpRequestNode(BaseNode): - pass + _node_data_cls = HttpRequestNodeData + node_type = NodeType.HTTP_REQUEST + + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + pass + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + pass \ No newline at end of file From 614bc2e075eee1ab938e363a9168d776002e4dc4 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Sat, 9 Mar 2024 22:19:48 +0800 Subject: [PATCH 135/200] feat: http reqeust --- api/core/helper/ssrf_proxy.py | 4 + .../workflow/nodes/http_request/entities.py | 5 +- .../nodes/http_request/http_executor.py | 240 ++++++++++++++++++ .../nodes/http_request/http_request_node.py | 39 ++- 4 files changed, 285 insertions(+), 3 deletions(-) create mode 100644 api/core/workflow/nodes/http_request/http_executor.py diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 0bfe763fac..c44d4717e6 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -38,6 +38,10 @@ def patch(url, *args, **kwargs): return _patch(url=url, *args, proxies=httpx_proxies, **kwargs) def delete(url, *args, **kwargs): + if 'follow_redirects' in kwargs: + if kwargs['follow_redirects']: + kwargs['allow_redirects'] = kwargs['follow_redirects'] + kwargs.pop('follow_redirects') return _delete(url=url, *args, proxies=requests_proxies, **kwargs) def head(url, *args, **kwargs): diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index 8610e88e55..1e906cbaa4 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -17,9 +17,10 @@ class HttpRequestNodeData(BaseNodeData): header: Union[None, str] type: Literal['no-auth', 'api-key'] + config: Config class Body(BaseModel): - type: Literal[None, 'form-data', 'x-www-form-urlencoded', 'raw'] + type: Literal[None, 'form-data', 'x-www-form-urlencoded', 'raw', 'json'] data: Union[None, str] variables: list[VariableSelector] @@ -28,4 +29,4 @@ class HttpRequestNodeData(BaseNodeData): authorization: Authorization headers: str params: str - \ No newline at end of file + body: Body \ No newline at end of file diff --git 
diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py new file mode 100644 index 0000000000..4b13e92e0c --- /dev/null +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -0,0 +1,240 @@ +from copy import deepcopy +from typing import Any, Union +from urllib.parse import urlencode + +import httpx +import re +import requests +import core.helper.ssrf_proxy as ssrf_proxy +from core.workflow.nodes.http_request.entities import HttpRequestNodeData + +HTTP_REQUEST_DEFAULT_TIMEOUT = (10, 60) + +class HttpExecutorResponse: + status_code: int + headers: dict[str, str] + body: str + + def __init__(self, status_code: int, headers: dict[str, str], body: str): + """ + init + """ + self.status_code = status_code + self.headers = headers + self.body = body + +class HttpExecutor: + server_url: str + method: str + authorization: HttpRequestNodeData.Authorization + params: dict[str, Any] + headers: dict[str, Any] + body: Union[None, str] + files: Union[None, dict[str, Any]] + + def __init__(self, node_data: HttpRequestNodeData, variables: dict[str, Any]): + """ + init + """ + self.server_url = node_data.url + self.method = node_data.method + self.authorization = node_data.authorization + self.params = {} + self.headers = {} + self.body = None + self.files = None + + # init template + self._init_template(node_data, variables) + + def _init_template(self, node_data: HttpRequestNodeData, variables: dict[str, Any]): + """ + init template + """ + # extract all template in url + url_template = re.findall(r'{{(.*?)}}', node_data.url) or [] + url_template = list(set(url_template)) + original_url = node_data.url + for url in url_template: + if not url: + continue + + original_url = original_url.replace(f'{{{{{url}}}}}', str(variables.get(url, ''))) + + self.server_url = original_url + + # extract all template in params + param_template = re.findall(r'{{(.*?)}}', node_data.params) or [] + param_template = list(set(param_template)) + original_params = node_data.params + for param in param_template: + if not param: + continue + + original_params = original_params.replace(f'{{{{{param}}}}}', str(variables.get(param, ''))) + + # fill in params + kv_pairs = original_params.split('\n') + for kv in kv_pairs: + kv = kv.split(':') + if len(kv) != 2: + raise ValueError(f'Invalid params {kv}') + + k, v = kv + self.params[k] = v + + # extract all template in headers + header_template = re.findall(r'{{(.*?)}}', node_data.headers) or [] + header_template = list(set(header_template)) + original_headers = node_data.headers + for header in header_template: + if not header: + continue + + original_headers = original_headers.replace(f'{{{{{header}}}}}', str(variables.get(header, ''))) + + # fill in headers + kv_pairs = original_headers.split('\n') + for kv in kv_pairs: + kv = kv.split(':') + if len(kv) != 2: + raise ValueError(f'Invalid headers {kv}') + + k, v = kv + self.headers[k] = v + + # extract all template in body + body_template = re.findall(r'{{(.*?)}}', node_data.body.data or '') or [] + body_template = list(set(body_template)) + original_body = node_data.body.data or '' + for body in body_template: + if not body: + continue + + original_body = original_body.replace(f'{{{{{body}}}}}', str(variables.get(body, ''))) + + if node_data.body.type == 'json': + self.headers['Content-Type'] = 'application/json' + elif node_data.body.type == 'x-www-form-urlencoded': + self.headers['Content-Type'] = 'application/x-www-form-urlencoded' + # elif node_data.body.type == 'form-data': + # self.headers['Content-Type'] = 'multipart/form-data'
+ + if node_data.body.type in ['form-data', 'x-www-form-urlencoded']: + body = {} + kv_pairs = original_body.split('\n') + for kv in kv_pairs: + kv = kv.split(':') + if len(kv) != 2: + raise ValueError(f'Invalid body {kv}') + body[kv[0]] = kv[1] + + if node_data.body.type == 'form-data': + self.files = { + k: ('', v) for k, v in body.items() + } + else: + self.body = urlencode(body) + else: + self.body = original_body + + def _assembling_headers(self) -> dict[str, Any]: + authorization = deepcopy(self.authorization) + headers = deepcopy(self.headers) or {} + if self.authorization.type == 'api-key': + if self.authorization.config.api_key is None: + raise ValueError('api_key is required') + + if not self.authorization.config.header: + authorization.config.header = 'Authorization' + + if self.authorization.config.type == 'bearer': + headers[authorization.config.header] = f'Bearer {authorization.config.api_key}' + elif self.authorization.config.type == 'basic': + headers[authorization.config.header] = f'Basic {authorization.config.api_key}' + elif self.authorization.config.type == 'custom': + headers[authorization.config.header] = authorization.config.api_key + + return headers + + def _validate_and_parse_response(self, response: Union[httpx.Response, requests.Response]) -> HttpExecutorResponse: + """ + validate the response + """ + if isinstance(response, httpx.Response): + # get key-value pairs headers + headers = {} + for k, v in response.headers.items(): + headers[k] = v + + return HttpExecutorResponse(response.status_code, headers, response.text) + elif isinstance(response, requests.Response): + # get key-value pairs headers + headers = {} + for k, v in response.headers.items(): + headers[k] = v + + return HttpExecutorResponse(response.status_code, headers, response.text) + else: + raise ValueError(f'Invalid response type {type(response)}') + + def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: + """ + do http request depending on api bundle + """ + # do http request + kwargs = { + 'url': self.server_url, + 'headers': headers, + 'params': self.params, + 'timeout': HTTP_REQUEST_DEFAULT_TIMEOUT, + 'follow_redirects': True + } + + if self.method == 'get': + response = ssrf_proxy.get(**kwargs) + elif self.method == 'post': + response = ssrf_proxy.post(data=self.body, files=self.files, **kwargs) + elif self.method == 'put': + response = ssrf_proxy.put(data=self.body, files=self.files, **kwargs) + elif self.method == 'delete': + response = ssrf_proxy.delete(data=self.body, files=self.files, **kwargs) + elif self.method == 'patch': + response = ssrf_proxy.patch(data=self.body, files=self.files, **kwargs) + elif self.method == 'head': + response = ssrf_proxy.head(**kwargs) + elif self.method == 'options': + response = ssrf_proxy.options(**kwargs) + else: + raise ValueError(f'Invalid http method {self.method}') + + return response + + def invoke(self) -> HttpExecutorResponse: + """ + invoke http request + """ + # assemble headers + headers = self._assembling_headers() + + # do http request + response = self._do_http_request(headers) + + # validate response + return self._validate_and_parse_response(response) + + def to_raw_request(self) -> str: + """ + convert to raw request + """ + server_url = self.server_url + if self.params: + server_url += f'?{urlencode(self.params)}' + + raw_request = f'{self.method.upper()} {server_url} HTTP/1.1\n' + for k, v in self.headers.items(): + raw_request += f'{k}: {v}\n' + + raw_request += '\n' + raw_request += self.body or '' + + return raw_request \ No newline at end of file
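With the executor in place, the node in the next diff drives it in three steps: resolve the wired variables, invoke, and wrap the parsed response. A rough usage sketch, reusing the illustrative node data from the earlier example:

```python
# Values resolved from the variable pool at run time (illustrative).
variables = {'user_id': '42', 'name': 'Alice'}

executor = HttpExecutor(node_data=node_data, variables=variables)

# The raw-request rendering doubles as process data for tracing, e.g.
# "POST https://api.example.com/users/42?page=1 HTTP/1.1" plus headers/body.
print(executor.to_raw_request())

response = executor.invoke()
print(response.status_code)                   # e.g. 200
print(response.headers.get('Content-Type'))   # headers parsed into a plain dict
print(response.body)                          # response text as parsed above
```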
'\n' + raw_request += self.body or '' + + return raw_request \ No newline at end of file diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index d0fa29646f..f55f48c4af 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -1,15 +1,52 @@ +from os import error +from typing import cast from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.http_request.entities import HttpRequestNodeData +from core.workflow.nodes.http_request.http_executor import HttpExecutor +from models.workflow import WorkflowNodeExecutionStatus + class HttpRequestNode(BaseNode): _node_data_cls = HttpRequestNodeData node_type = NodeType.HTTP_REQUEST def _run(self, variable_pool: VariablePool) -> NodeRunResult: - pass + node_data: HttpRequestNodeData = cast(self._node_data_cls, self.node_data) + + # extract variables + variables = { + variable_selector.variable: variable_pool.get_variable_value(variable_selector=variable_selector.value_selector) + for variable_selector in node_data.variables + } + + # init http executor + try: + http_executor = HttpExecutor(node_data=node_data, variables=variables) + # invoke http executor + + response = http_executor.invoke() + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error=str(e), + process_data=http_executor.to_raw_request() + ) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variables, + outputs={ + 'status_code': response.status_code, + 'body': response, + 'headers': response.headers + }, + process_data=http_executor.to_raw_request() + ) + @classmethod def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: From 3d5f9b5a1eb9d2921339ac3fe96b0dd6426170af Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Sat, 9 Mar 2024 22:26:19 +0800 Subject: [PATCH 136/200] fix: missing _extract_variable_selector_to_variable_mapping --- api/core/workflow/nodes/http_request/http_executor.py | 3 ++- api/core/workflow/nodes/http_request/http_request_node.py | 8 +++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py index 4b13e92e0c..82d879a89c 100644 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -1,10 +1,11 @@ +import re from copy import deepcopy from typing import Any, Union from urllib.parse import urlencode import httpx -import re import requests + import core.helper.ssrf_proxy as ssrf_proxy from core.workflow.nodes.http_request.entities import HttpRequestNodeData diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index f55f48c4af..e3e864b6b0 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -1,5 +1,5 @@ -from os import error from typing import cast + from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from 
core.workflow.entities.variable_pool import VariablePool @@ -49,10 +49,12 @@ class HttpRequestNode(BaseNode): @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: HttpRequestNodeData) -> dict[list[str], str]: """ Extract variable selector to variable mapping :param node_data: node data :return: """ - pass \ No newline at end of file + return { + variable_selector.value_selector: variable_selector.variable for variable_selector in node_data.variables + } \ No newline at end of file From 2895c3bc8c997efcaef70f6008917e38c4366d22 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Sat, 9 Mar 2024 22:49:53 +0800 Subject: [PATCH 137/200] feat: template transform --- .../code_executor}/code_executor.py | 15 +++-- .../code_executor/javascript_transformer.py | 1 + .../helper/code_executor/jina2_transformer.py | 1 + .../code_executor/python_transformer.py} | 4 +- .../code_executor/template_transformer.py | 24 ++++++++ api/core/workflow/nodes/code/code_node.py | 2 +- api/core/workflow/nodes/code/entities.py | 2 +- .../nodes/http_request/http_request_node.py | 1 - .../nodes/template_transform/entities.py | 14 +++++ .../template_transform_node.py | 59 ++++++++++++++++++- 10 files changed, 114 insertions(+), 9 deletions(-) rename api/core/{workflow/nodes/code => helper/code_executor}/code_executor.py (75%) create mode 100644 api/core/helper/code_executor/javascript_transformer.py create mode 100644 api/core/helper/code_executor/jina2_transformer.py rename api/core/{workflow/nodes/code/python_template.py => helper/code_executor/python_transformer.py} (90%) create mode 100644 api/core/helper/code_executor/template_transformer.py create mode 100644 api/core/workflow/nodes/template_transform/entities.py diff --git a/api/core/workflow/nodes/code/code_executor.py b/api/core/helper/code_executor/code_executor.py similarity index 75% rename from api/core/workflow/nodes/code/code_executor.py rename to api/core/helper/code_executor/code_executor.py index 058ee83d46..f1bc4fbdaf 100644 --- a/api/core/workflow/nodes/code/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -1,10 +1,11 @@ from os import environ +from typing import Literal from httpx import post from pydantic import BaseModel from yarl import URL -from core.workflow.nodes.code.python_template import PythonTemplateTransformer +from core.helper.code_executor.python_transformer import PythonTemplateTransformer # Code Executor CODE_EXECUTION_ENDPOINT = environ.get('CODE_EXECUTION_ENDPOINT', '') @@ -24,7 +25,7 @@ class CodeExecutionResponse(BaseModel): class CodeExecutor: @classmethod - def execute_code(cls, language: str, code: str, inputs: dict) -> dict: + def execute_code(cls, language: Literal['python3', 'javascript', 'jina2'], code: str, inputs: dict) -> dict: """ Execute code :param language: code language @@ -32,7 +33,13 @@ class CodeExecutor: :param inputs: inputs :return: """ - runner = PythonTemplateTransformer.transform_caller(code, inputs) + template_transformer = None + if language == 'python3': + template_transformer = PythonTemplateTransformer + else: + raise CodeExecutionException('Unsupported language') + + runner = template_transformer.transform_caller(code, inputs) url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'run' headers = { @@ -67,4 +74,4 @@ class CodeExecutor: if response.data.stderr: raise CodeExecutionException(response.data.stderr) - return 
PythonTemplateTransformer.transform_response(response.data.stdout) \ No newline at end of file + return template_transformer.transform_response(response.data.stdout) \ No newline at end of file diff --git a/api/core/helper/code_executor/javascript_transformer.py b/api/core/helper/code_executor/javascript_transformer.py new file mode 100644 index 0000000000..f87f5c14cb --- /dev/null +++ b/api/core/helper/code_executor/javascript_transformer.py @@ -0,0 +1 @@ +# TODO \ No newline at end of file diff --git a/api/core/helper/code_executor/jina2_transformer.py b/api/core/helper/code_executor/jina2_transformer.py new file mode 100644 index 0000000000..f87f5c14cb --- /dev/null +++ b/api/core/helper/code_executor/jina2_transformer.py @@ -0,0 +1 @@ +# TODO \ No newline at end of file diff --git a/api/core/workflow/nodes/code/python_template.py b/api/core/helper/code_executor/python_transformer.py similarity index 90% rename from api/core/workflow/nodes/code/python_template.py rename to api/core/helper/code_executor/python_transformer.py index 03dfee36f3..7b862649d8 100644 --- a/api/core/workflow/nodes/code/python_template.py +++ b/api/core/helper/code_executor/python_transformer.py @@ -1,6 +1,8 @@ import json import re +from core.helper.code_executor.template_transformer import TemplateTransformer + PYTHON_RUNNER = """# declare main function here {{code}} @@ -19,7 +21,7 @@ print(result) """ -class PythonTemplateTransformer: +class PythonTemplateTransformer(TemplateTransformer): @classmethod def transform_caller(cls, code: str, inputs: dict) -> str: """ diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py new file mode 100644 index 0000000000..5505df8749 --- /dev/null +++ b/api/core/helper/code_executor/template_transformer.py @@ -0,0 +1,24 @@ +from abc import ABC, abstractmethod + + +class TemplateTransformer(ABC): + @classmethod + @abstractmethod + def transform_caller(cls, code: str, inputs: dict) -> str: + """ + Transform code to python runner + :param code: code + :param inputs: inputs + :return: + """ + pass + + @classmethod + @abstractmethod + def transform_response(cls, response: str) -> dict: + """ + Transform response to dict + :param response: response + :return: + """ + pass \ No newline at end of file diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 3d3c475d06..7d3162d983 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,9 +1,9 @@ from typing import Optional, Union, cast +from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.code.code_executor import CodeExecutionException, CodeExecutor from core.workflow.nodes.code.entities import CodeNodeData from models.workflow import WorkflowNodeExecutionStatus diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index 2212d77e2d..6a18d181cb 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -16,6 +16,6 @@ class CodeNodeData(BaseNodeData): variables: list[VariableSelector] answer: str - code_language: str + code_language: Literal['python3', 'javascript'] code: str outputs: dict[str, Output] diff --git 
a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index e3e864b6b0..4ee76deb83 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -1,6 +1,5 @@ from typing import cast -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode diff --git a/api/core/workflow/nodes/template_transform/entities.py b/api/core/workflow/nodes/template_transform/entities.py new file mode 100644 index 0000000000..2d3d35b84c --- /dev/null +++ b/api/core/workflow/nodes/template_transform/entities.py @@ -0,0 +1,14 @@ +from typing import Literal, Union + +from pydantic import BaseModel + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_entities import VariableSelector + + +class TemplateTransformNodeData(BaseNodeData): + """ + Template Transform Node Data. + """ + variables: list[VariableSelector] + template: str \ No newline at end of file diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 2bf26e307e..3fb880d926 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -1,9 +1,18 @@ -from typing import Optional +from typing import Optional, cast +from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData +from models.workflow import WorkflowNodeExecutionStatus class TemplateTransformNode(BaseNode): + _node_data_cls = TemplateTransformNodeData + _node_type = NodeType.TEMPLATE_TRANSFORM + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ @@ -23,3 +32,51 @@ class TemplateTransformNode(BaseNode): "template": "{{ arg1 }}" } } + + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + """ + Run node + """ + node_data = self.node_data + node_data: TemplateTransformNodeData = cast(self._node_data_cls, node_data) + + # Get variables + variables = {} + for variable_selector in node_data.variables: + variable = variable_selector.variable + value = variable_pool.get_variable_value( + variable_selector=variable_selector.value_selector + ) + + variables[variable] = value + + # Run code + try: + result = CodeExecutor.execute_code( + language='jina2', + code=node_data.template, + inputs=variables + ) + except CodeExecutionException as e: + return NodeRunResult( + inputs=variables, + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e) + ) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variables, + outputs=result['result'] + ) + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[list[str], str]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + return {
variable_selector.value_selector: variable_selector.variable for variable_selector in node_data.variables + } \ No newline at end of file From 51f6ab49cf15bc1edbaa68c29288057cda5c1a99 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Sat, 9 Mar 2024 22:50:11 +0800 Subject: [PATCH 138/200] fix: linter --- api/core/workflow/nodes/template_transform/entities.py | 2 -- .../nodes/template_transform/template_transform_node.py | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/api/core/workflow/nodes/template_transform/entities.py b/api/core/workflow/nodes/template_transform/entities.py index 2d3d35b84c..d9099a8118 100644 --- a/api/core/workflow/nodes/template_transform/entities.py +++ b/api/core/workflow/nodes/template_transform/entities.py @@ -1,6 +1,4 @@ -from typing import Literal, Union -from pydantic import BaseModel from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 3fb880d926..724b84495c 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -1,9 +1,8 @@ from typing import Optional, cast + from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool - from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData from models.workflow import WorkflowNodeExecutionStatus From de3978fdbb7a0b41883afd493af4abee718f651f Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 10 Mar 2024 13:19:17 +0800 Subject: [PATCH 139/200] optimize db connections --- api/config.py | 2 ++ api/core/app/apps/advanced_chat/app_generator.py | 13 ++++++++++--- .../apps/advanced_chat/generate_task_pipeline.py | 2 ++ api/core/app/apps/message_based_app_generator.py | 8 ++++++++ .../app/apps/workflow/generate_task_pipeline.py | 2 ++ .../apps/workflow_based_generate_task_pipeline.py | 11 +++++++++++ api/core/workflow/workflow_engine_manager.py | 5 +++++ 7 files changed, 40 insertions(+), 3 deletions(-) diff --git a/api/config.py b/api/config.py index a6bc731b82..a4ec6fcef9 100644 --- a/api/config.py +++ b/api/config.py @@ -27,6 +27,7 @@ DEFAULTS = { 'CHECK_UPDATE_URL': 'https://updates.dify.ai', 'DEPLOY_ENV': 'PRODUCTION', 'SQLALCHEMY_POOL_SIZE': 30, + 'SQLALCHEMY_MAX_OVERFLOW': 10, 'SQLALCHEMY_POOL_RECYCLE': 3600, 'SQLALCHEMY_ECHO': 'False', 'SENTRY_TRACES_SAMPLE_RATE': 1.0, @@ -148,6 +149,7 @@ class Config: self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}{db_extras}" self.SQLALCHEMY_ENGINE_OPTIONS = { 'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')), + 'max_overflow': int(get_env('SQLALCHEMY_MAX_OVERFLOW')), 'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE')) } diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index a0f197ec37..50b561dfe6 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ 
b/api/core/app/apps/advanced_chat/app_generator.py @@ -95,6 +95,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): extras=extras ) + workflow = db.session.query(Workflow).filter(Workflow.id == workflow.id).first() + user = (db.session.query(Account).filter(Account.id == user.id).first() + if isinstance(user, Account) + else db.session.query(EndUser).filter(EndUser.id == user.id).first()) + db.session.close() + # init generate records ( conversation, @@ -153,6 +159,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) + db.session.close() + # chatbot app runner = AdvancedChatAppRunner() runner.run( @@ -177,7 +185,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): logger.exception("Unknown Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) finally: - db.session.remove() + db.session.close() def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity, workflow: Workflow, @@ -198,6 +206,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :return: """ # init generate task pipeline + generate_task_pipeline = AdvancedChatAppGenerateTaskPipeline( application_generate_entity=application_generate_entity, workflow=workflow, @@ -216,5 +225,3 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): else: logger.exception(e) raise e - # finally: - # db.session.remove() diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 048b429304..6991b8704a 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -122,6 +122,8 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): self._output_moderation_handler = self._init_output_moderation() self._stream = stream + db.session.close() + def process(self) -> Union[dict, Generator]: """ Process generate task pipeline. 
diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 0e76c96ff7..be7538ea07 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -177,6 +177,9 @@ class MessageBasedAppGenerator(BaseAppGenerator): db.session.add(conversation) db.session.commit() + conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() + db.session.close() + message = Message( app_id=app_config.app_id, model_provider=model_provider, @@ -204,6 +207,9 @@ class MessageBasedAppGenerator(BaseAppGenerator): db.session.add(message) db.session.commit() + message = db.session.query(Message).filter(Message.id == message.id).first() + db.session.close() + for file in application_generate_entity.files: message_file = MessageFile( message_id=message.id, @@ -218,6 +224,8 @@ class MessageBasedAppGenerator(BaseAppGenerator): db.session.add(message_file) db.session.commit() + db.session.close() + return conversation, message def _get_conversation_introduction(self, application_generate_entity: AppGenerateEntity) -> str: diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 26e4769fa6..2c2f941bee 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -99,6 +99,8 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): self._output_moderation_handler = self._init_output_moderation() self._stream = stream + db.session.close() + def process(self) -> Union[dict, Generator]: """ Process generate task pipeline. diff --git a/api/core/app/apps/workflow_based_generate_task_pipeline.py b/api/core/app/apps/workflow_based_generate_task_pipeline.py index 3e9a7b9e1f..640159bae3 100644 --- a/api/core/app/apps/workflow_based_generate_task_pipeline.py +++ b/api/core/app/apps/workflow_based_generate_task_pipeline.py @@ -61,6 +61,9 @@ class WorkflowBasedGenerateTaskPipeline: db.session.add(workflow_run) db.session.commit() + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run.id).first() + db.session.close() + return workflow_run def _workflow_run_success(self, workflow_run: WorkflowRun, @@ -85,6 +88,7 @@ class WorkflowBasedGenerateTaskPipeline: workflow_run.finished_at = datetime.utcnow() db.session.commit() + db.session.close() return workflow_run @@ -112,6 +116,7 @@ class WorkflowBasedGenerateTaskPipeline: workflow_run.finished_at = datetime.utcnow() db.session.commit() + db.session.close() return workflow_run @@ -151,6 +156,10 @@ class WorkflowBasedGenerateTaskPipeline: db.session.add(workflow_node_execution) db.session.commit() + workflow_node_execution = (db.session.query(WorkflowNodeExecution) + .filter(WorkflowNodeExecution.id == workflow_node_execution.id).first()) + db.session.close() + return workflow_node_execution def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution, @@ -179,6 +188,7 @@ class WorkflowBasedGenerateTaskPipeline: workflow_node_execution.finished_at = datetime.utcnow() db.session.commit() + db.session.close() return workflow_node_execution @@ -198,5 +208,6 @@ class WorkflowBasedGenerateTaskPipeline: workflow_node_execution.finished_at = datetime.utcnow() db.session.commit() + db.session.close() return workflow_node_execution diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 
0b96717de7..50f79df1f0 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -19,6 +19,7 @@ from core.workflow.nodes.start.start_node import StartNode from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from core.workflow.nodes.tool.tool_node import ToolNode from core.workflow.nodes.variable_assigner.variable_assigner_node import VariableAssignerNode +from extensions.ext_database import db from models.workflow import ( Workflow, WorkflowNodeExecutionStatus, @@ -282,6 +283,8 @@ class WorkflowEngineManager: predecessor_node_id=predecessor_node.node_id if predecessor_node else None ) + db.session.close() + workflow_nodes_and_result = WorkflowNodeAndResult( node=node, result=None @@ -339,6 +342,8 @@ class WorkflowEngineManager: if node_run_result.metadata and node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): workflow_run_state.total_tokens += int(node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)) + db.session.close() + def _set_end_node_output_if_in_chat(self, workflow_run_state: WorkflowRunState, node: BaseNode, node_run_result: NodeRunResult) -> None: From 7e4daf131e7da3ab7eb081020edc01260f0d97b6 Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 10 Mar 2024 14:49:52 +0800 Subject: [PATCH 140/200] optimize db connections --- api/core/app/apps/advanced_chat/app_generator.py | 7 ------- .../app/apps/advanced_chat/generate_task_pipeline.py | 6 ++++-- api/core/app/apps/message_based_app_generator.py | 10 ++-------- api/core/app/apps/workflow/generate_task_pipeline.py | 6 ++++-- .../app/apps/workflow_based_generate_task_pipeline.py | 7 ++----- 5 files changed, 12 insertions(+), 24 deletions(-) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 50b561dfe6..b1bc839966 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -95,12 +95,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): extras=extras ) - workflow = db.session.query(Workflow).filter(Workflow.id == workflow.id).first() - user = (db.session.query(Account).filter(Account.id == user.id).first() - if isinstance(user, Account) - else db.session.query(EndUser).filter(EndUser.id == user.id).first()) - db.session.close() - # init generate records ( conversation, @@ -206,7 +200,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :return: """ # init generate task pipeline - generate_task_pipeline = AdvancedChatAppGenerateTaskPipeline( application_generate_entity=application_generate_entity, workflow=workflow, diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 6991b8704a..88ac5fd235 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -122,13 +122,15 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): self._output_moderation_handler = self._init_output_moderation() self._stream = stream - db.session.close() - def process(self) -> Union[dict, Generator]: """ Process generate task pipeline. 
:return: """ + db.session.refresh(self._workflow) + db.session.refresh(self._user) + db.session.close() + if self._stream: return self._process_stream_response() else: diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index be7538ea07..5d0f4bc63a 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -176,9 +176,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): db.session.add(conversation) db.session.commit() - - conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() - db.session.close() + db.session.refresh(conversation) message = Message( app_id=app_config.app_id, @@ -206,9 +204,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): db.session.add(message) db.session.commit() - - message = db.session.query(Message).filter(Message.id == message.id).first() - db.session.close() + db.session.refresh(message) for file in application_generate_entity.files: message_file = MessageFile( @@ -224,8 +220,6 @@ class MessageBasedAppGenerator(BaseAppGenerator): db.session.add(message_file) db.session.commit() - db.session.close() - return conversation, message def _get_conversation_introduction(self, application_generate_entity: AppGenerateEntity) -> str: diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 2c2f941bee..9bd20f9785 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -99,13 +99,15 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): self._output_moderation_handler = self._init_output_moderation() self._stream = stream - db.session.close() - def process(self) -> Union[dict, Generator]: """ Process generate task pipeline. 
:return: """ + db.session.refresh(self._workflow) + db.session.refresh(self._user) + db.session.close() + if self._stream: return self._process_stream_response() else: diff --git a/api/core/app/apps/workflow_based_generate_task_pipeline.py b/api/core/app/apps/workflow_based_generate_task_pipeline.py index 640159bae3..d29cee3ac4 100644 --- a/api/core/app/apps/workflow_based_generate_task_pipeline.py +++ b/api/core/app/apps/workflow_based_generate_task_pipeline.py @@ -60,8 +60,7 @@ class WorkflowBasedGenerateTaskPipeline: db.session.add(workflow_run) db.session.commit() - - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run.id).first() + db.session.refresh(workflow_run) db.session.close() return workflow_run @@ -155,9 +154,7 @@ class WorkflowBasedGenerateTaskPipeline: db.session.add(workflow_node_execution) db.session.commit() - - workflow_node_execution = (db.session.query(WorkflowNodeExecution) - .filter(WorkflowNodeExecution.id == workflow_node_execution.id).first()) + db.session.refresh(workflow_node_execution) db.session.close() return workflow_node_execution From 8b832097de7316238a0713c05eca839a468863b0 Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 10 Mar 2024 16:29:55 +0800 Subject: [PATCH 141/200] optimize db connections --- api/controllers/console/app/app.py | 72 +++++---- api/controllers/console/app/model_config.py | 145 +++++++++--------- .../easy_ui_based_app/dataset/manager.py | 3 +- .../app/apps/advanced_chat/app_generator.py | 2 - api/core/app/apps/advanced_chat/app_runner.py | 2 + api/core/app/apps/agent_chat/app_generator.py | 2 +- api/core/app/apps/agent_chat/app_runner.py | 4 +- api/core/app/apps/chat/app_generator.py | 2 +- api/core/app/apps/completion/app_generator.py | 2 +- api/core/app/apps/completion/app_runner.py | 2 + .../app/apps/message_based_app_generator.py | 2 - api/core/app/apps/workflow/app_runner.py | 2 + api/core/tools/tool_manager.py | 2 +- api/models/model.py | 2 +- 14 files changed, 126 insertions(+), 118 deletions(-) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 66bcbccefe..9440603069 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,3 +1,5 @@ +import json + from flask_login import current_user from flask_restful import Resource, inputs, marshal_with, reqparse from werkzeug.exceptions import Forbidden, BadRequest @@ -6,6 +8,8 @@ from controllers.console import api from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from core.agent.entities import AgentToolEntity +from extensions.ext_database import db from fields.app_fields import ( app_detail_fields, app_detail_fields_with_site, @@ -14,10 +18,8 @@ from fields.app_fields import ( from libs.login import login_required from services.app_service import AppService from models.model import App, AppModelConfig, AppMode -from services.workflow_service import WorkflowService from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.tool_manager import ToolManager -from core.entities.application_entities import AgentToolEntity ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow'] @@ -108,41 +110,43 @@ class AppApi(Resource): def get(self, app_model): """Get app detail""" # get original app model config - model_config: AppModelConfig = app_model.app_model_config - 
agent_mode = model_config.agent_mode_dict - # decrypt agent tool parameters if it's secret-input - for tool in agent_mode.get('tools') or []: - if not isinstance(tool, dict) or len(tool.keys()) <= 3: - continue - agent_tool_entity = AgentToolEntity(**tool) - # get tool - try: - tool_runtime = ToolManager.get_agent_tool_runtime( - tenant_id=current_user.current_tenant_id, - agent_tool=agent_tool_entity, - agent_callback=None - ) - manager = ToolParameterConfigurationManager( - tenant_id=current_user.current_tenant_id, - tool_runtime=tool_runtime, - provider_name=agent_tool_entity.provider_id, - provider_type=agent_tool_entity.provider_type, - ) + if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: + model_config: AppModelConfig = app_model.app_model_config + agent_mode = model_config.agent_mode_dict + # decrypt agent tool parameters if it's secret-input + for tool in agent_mode.get('tools') or []: + if not isinstance(tool, dict) or len(tool.keys()) <= 3: + continue + agent_tool_entity = AgentToolEntity(**tool) + # get tool + try: + tool_runtime = ToolManager.get_agent_tool_runtime( + tenant_id=current_user.current_tenant_id, + agent_tool=agent_tool_entity, + agent_callback=None + ) + manager = ToolParameterConfigurationManager( + tenant_id=current_user.current_tenant_id, + tool_runtime=tool_runtime, + provider_name=agent_tool_entity.provider_id, + provider_type=agent_tool_entity.provider_type, + ) - # get decrypted parameters - if agent_tool_entity.tool_parameters: - parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) - masked_parameter = manager.mask_tool_parameters(parameters or {}) - else: - masked_parameter = {} + # get decrypted parameters + if agent_tool_entity.tool_parameters: + parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + masked_parameter = manager.mask_tool_parameters(parameters or {}) + else: + masked_parameter = {} - # override tool parameters - tool['tool_parameters'] = masked_parameter - except Exception as e: - pass + # override tool parameters + tool['tool_parameters'] = masked_parameter + except Exception as e: + pass - # override agent mode - model_config.agent_mode = json.dumps(agent_mode) + # override agent mode + model_config.agent_mode = json.dumps(agent_mode) + db.session.commit() return app_model diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 1301d12da4..41b7151ba6 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -8,7 +8,7 @@ from controllers.console import api from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.entities.application_entities import AgentToolEntity +from core.agent.entities import AgentToolEntity from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_model_config_was_updated @@ -38,90 +38,91 @@ class ModelConfigResource(Resource): ) new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) - # get original app model config - original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter( - AppModelConfig.id == app.app_model_config_id - ).first() - agent_mode = original_app_model_config.agent_mode_dict - # decrypt agent tool parameters if it's 
secret-input - parameter_map = {} - masked_parameter_map = {} - tool_map = {} - for tool in agent_mode.get('tools') or []: - if not isinstance(tool, dict) or len(tool.keys()) <= 3: - continue - - agent_tool_entity = AgentToolEntity(**tool) - # get tool - try: - tool_runtime = ToolManager.get_agent_tool_runtime( - tenant_id=current_user.current_tenant_id, - agent_tool=agent_tool_entity, - agent_callback=None - ) - manager = ToolParameterConfigurationManager( - tenant_id=current_user.current_tenant_id, - tool_runtime=tool_runtime, - provider_name=agent_tool_entity.provider_id, - provider_type=agent_tool_entity.provider_type, - ) - except Exception as e: - continue + if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: + # get original app model config + original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter( + AppModelConfig.id == app_model.app_model_config_id + ).first() + agent_mode = original_app_model_config.agent_mode_dict + # decrypt agent tool parameters if it's secret-input + parameter_map = {} + masked_parameter_map = {} + tool_map = {} + for tool in agent_mode.get('tools') or []: + if not isinstance(tool, dict) or len(tool.keys()) <= 3: + continue - # get decrypted parameters - if agent_tool_entity.tool_parameters: - parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) - masked_parameter = manager.mask_tool_parameters(parameters or {}) - else: - parameters = {} - masked_parameter = {} - - key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' - masked_parameter_map[key] = masked_parameter - parameter_map[key] = parameters - tool_map[key] = tool_runtime - - # encrypt agent tool parameters if it's secret-input - agent_mode = new_app_model_config.agent_mode_dict - for tool in agent_mode.get('tools') or []: - agent_tool_entity = AgentToolEntity(**tool) - - # get tool - key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' - if key in tool_map: - tool_runtime = tool_map[key] - else: + agent_tool_entity = AgentToolEntity(**tool) + # get tool try: tool_runtime = ToolManager.get_agent_tool_runtime( tenant_id=current_user.current_tenant_id, agent_tool=agent_tool_entity, agent_callback=None ) + manager = ToolParameterConfigurationManager( + tenant_id=current_user.current_tenant_id, + tool_runtime=tool_runtime, + provider_name=agent_tool_entity.provider_id, + provider_type=agent_tool_entity.provider_type, + ) except Exception as e: continue - - manager = ToolParameterConfigurationManager( - tenant_id=current_user.current_tenant_id, - tool_runtime=tool_runtime, - provider_name=agent_tool_entity.provider_id, - provider_type=agent_tool_entity.provider_type, - ) - manager.delete_tool_parameters_cache() - # override parameters if it equals to masked parameters - if agent_tool_entity.tool_parameters: - if key not in masked_parameter_map: - continue + # get decrypted parameters + if agent_tool_entity.tool_parameters: + parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + masked_parameter = manager.mask_tool_parameters(parameters or {}) + else: + parameters = {} + masked_parameter = {} - if agent_tool_entity.tool_parameters == masked_parameter_map[key]: - agent_tool_entity.tool_parameters = parameter_map[key] + key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' + masked_parameter_map[key] = masked_parameter + parameter_map[key] = 
parameters + tool_map[key] = tool_runtime - # encrypt parameters - if agent_tool_entity.tool_parameters: - tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + # encrypt agent tool parameters if it's secret-input + agent_mode = new_app_model_config.agent_mode_dict + for tool in agent_mode.get('tools') or []: + agent_tool_entity = AgentToolEntity(**tool) - # update app model config - new_app_model_config.agent_mode = json.dumps(agent_mode) + # get tool + key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' + if key in tool_map: + tool_runtime = tool_map[key] + else: + try: + tool_runtime = ToolManager.get_agent_tool_runtime( + tenant_id=current_user.current_tenant_id, + agent_tool=agent_tool_entity, + agent_callback=None + ) + except Exception as e: + continue + + manager = ToolParameterConfigurationManager( + tenant_id=current_user.current_tenant_id, + tool_runtime=tool_runtime, + provider_name=agent_tool_entity.provider_id, + provider_type=agent_tool_entity.provider_type, + ) + manager.delete_tool_parameters_cache() + + # override parameters if it equals to masked parameters + if agent_tool_entity.tool_parameters: + if key not in masked_parameter_map: + continue + + if agent_tool_entity.tool_parameters == masked_parameter_map[key]: + agent_tool_entity.tool_parameters = parameter_map[key] + + # encrypt parameters + if agent_tool_entity.tool_parameters: + tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + + # update app model config + new_app_model_config.agent_mode = json.dumps(agent_mode) db.session.add(new_app_model_config) db.session.flush() diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index 4c08f62d27..c10aa98dba 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -123,7 +123,8 @@ class DatasetConfigManager: if not isinstance(config["dataset_configs"], dict): raise ValueError("dataset_configs must be of object type") - need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get("datasets") + need_manual_query_datasets = (config.get("dataset_configs") + and config["dataset_configs"].get("datasets", {}).get("datasets")) if need_manual_query_datasets and app_mode == AppMode.COMPLETION: # Only check when mode is completion diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index b1bc839966..1a33a3230b 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -153,8 +153,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) - db.session.close() - # chatbot app runner = AdvancedChatAppRunner() runner.run( diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 3279e00355..c42620b92f 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -72,6 +72,8 @@ class AdvancedChatAppRunner(AppRunner): ): return + db.session.close() + # RUN WORKFLOW workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( diff --git a/api/core/app/apps/agent_chat/app_generator.py 
b/api/core/app/apps/agent_chat/app_generator.py index 700a340c96..cc9b0785f5 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -193,4 +193,4 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): logger.exception("Unknown Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) finally: - db.session.remove() + db.session.close() diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 2e142c63f1..0dc8a1e218 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -201,8 +201,8 @@ class AgentChatAppRunner(AppRunner): if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []): agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING - db.session.refresh(conversation) - db.session.refresh(message) + conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() + message = db.session.query(Message).filter(Message.id == message.id).first() db.session.close() # start agent runner diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 317d045c04..58287ba658 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -193,4 +193,4 @@ class ChatAppGenerator(MessageBasedAppGenerator): logger.exception("Unknown Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) finally: - db.session.remove() + db.session.close() diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index b948938aac..fb62469720 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -182,7 +182,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): logger.exception("Unknown Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) finally: - db.session.remove() + db.session.close() def generate_more_like_this(self, app_model: App, message_id: str, diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 04adf77be5..649d73d961 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -160,6 +160,8 @@ class CompletionAppRunner(AppRunner): model=application_generate_entity.model_config.model ) + db.session.close() + invoke_result = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=application_generate_entity.model_config.parameters, diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 5d0f4bc63a..5e676c40bd 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -64,8 +64,6 @@ class MessageBasedAppGenerator(BaseAppGenerator): else: logger.exception(e) raise e - finally: - db.session.remove() def _get_conversation_by_user(self, app_model: App, conversation_id: str, user: Union[Account, EndUser]) -> Conversation: diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 59a385cb38..2d032fcdcb 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -57,6 +57,8 @@ class WorkflowAppRunner: ): return + db.session.close() + # RUN WORKFLOW 
workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 2ac8f27bab..24b2f287c1 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -5,8 +5,8 @@ import mimetypes from os import listdir, path from typing import Any, Union +from core.agent.entities import AgentToolEntity from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler -from core.entities.application_entities import AgentToolEntity from core.model_runtime.entities.message_entities import PromptMessage from core.provider_manager import ProviderManager from core.tools.entities.common_entities import I18nObject diff --git a/api/models/model.py b/api/models/model.py index 6856c4e1b0..5a7311a0c7 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -322,7 +322,7 @@ class AppModelConfig(db.Model): } def from_model_config_dict(self, model_config: dict): - self.opening_statement = model_config['opening_statement'] + self.opening_statement = model_config.get('opening_statement') self.suggested_questions = json.dumps(model_config['suggested_questions']) \ if model_config.get('suggested_questions') else None self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) \ From 61a1aadf9ca04c09008daf9d6914d2c60ada7c42 Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 10 Mar 2024 16:59:17 +0800 Subject: [PATCH 142/200] optimize workflow db connections --- .../advanced_chat/generate_task_pipeline.py | 99 ++++++++++--------- .../apps/workflow/generate_task_pipeline.py | 98 +++++++++--------- .../workflow_based_generate_task_pipeline.py | 4 + 3 files changed, 105 insertions(+), 96 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 88ac5fd235..d5d3feded0 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -59,7 +59,7 @@ class TaskState(BaseModel): """ NodeExecutionInfo entity """ - workflow_node_execution: WorkflowNodeExecution + workflow_node_execution_id: str start_at: float class Config: @@ -72,7 +72,7 @@ class TaskState(BaseModel): metadata: dict = {} usage: LLMUsage - workflow_run: Optional[WorkflowRun] = None + workflow_run_id: Optional[str] = None start_at: Optional[float] = None total_tokens: int = 0 total_steps: int = 0 @@ -168,8 +168,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): self._on_node_finished(event) elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - self._on_workflow_finished(event) - workflow_run = self._task_state.workflow_run + workflow_run = self._on_workflow_finished(event) if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value: raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))) @@ -218,8 +217,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(data) break elif isinstance(event, QueueWorkflowStartedEvent): - self._on_workflow_start() - workflow_run = self._task_state.workflow_run + workflow_run = self._on_workflow_start() response = { 'event': 'workflow_started', @@ -234,8 +232,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield 
self._yield_response(response) elif isinstance(event, QueueNodeStartedEvent): - self._on_node_start(event) - workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution + workflow_node_execution = self._on_node_start(event) response = { 'event': 'node_started', @@ -253,8 +250,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(response) elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): - self._on_node_finished(event) - workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution + workflow_node_execution = self._on_node_finished(event) if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value: if workflow_node_execution.node_type == NodeType.LLM.value: @@ -285,8 +281,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(response) elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - self._on_workflow_finished(event) - workflow_run = self._task_state.workflow_run + workflow_run = self._on_workflow_finished(event) if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value: err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) @@ -435,7 +430,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): else: continue - def _on_workflow_start(self) -> None: + def _on_workflow_start(self) -> WorkflowRun: self._task_state.start_at = time.perf_counter() workflow_run = self._init_workflow_run( @@ -452,11 +447,16 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): } ) - self._task_state.workflow_run = workflow_run + self._task_state.workflow_run_id = workflow_run.id - def _on_node_start(self, event: QueueNodeStartedEvent) -> None: + db.session.close() + + return workflow_run + + def _on_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() workflow_node_execution = self._init_node_execution_from_workflow_run( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, node_id=event.node_id, node_type=event.node_type, node_title=event.node_data.title, @@ -465,19 +465,26 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): ) latest_node_execution_info = TaskState.NodeExecutionInfo( - workflow_node_execution=workflow_node_execution, + workflow_node_execution_id=workflow_node_execution.id, start_at=time.perf_counter() ) self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info self._task_state.latest_node_execution_info = latest_node_execution_info + self._task_state.total_steps += 1 - def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None: + db.session.close() + + return workflow_node_execution + + def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution: current_node_execution = self._task_state.running_node_execution_infos[event.node_id] + workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first() if isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._workflow_node_execution_success( - 
workflow_node_execution=current_node_execution.workflow_node_execution, + workflow_node_execution=workflow_node_execution, start_at=current_node_execution.start_at, inputs=event.inputs, process_data=event.process_data, @@ -495,19 +502,24 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): self._task_state.metadata['usage'] = usage_dict else: workflow_node_execution = self._workflow_node_execution_failed( - workflow_node_execution=current_node_execution.workflow_node_execution, + workflow_node_execution=workflow_node_execution, start_at=current_node_execution.start_at, error=event.error ) - # remove running node execution info - del self._task_state.running_node_execution_infos[event.node_id] - self._task_state.latest_node_execution_info.workflow_node_execution = workflow_node_execution + # remove running node execution info + del self._task_state.running_node_execution_infos[event.node_id] - def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None: + db.session.close() + + return workflow_node_execution + + def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \ + -> WorkflowRun: + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() if isinstance(event, QueueStopEvent): workflow_run = self._workflow_run_failed( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, @@ -516,7 +528,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): ) elif isinstance(event, QueueWorkflowFailedEvent): workflow_run = self._workflow_run_failed( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, @@ -524,39 +536,30 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): error=event.error ) else: + if self._task_state.latest_node_execution_info: + workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first() + outputs = workflow_node_execution.outputs + else: + outputs = None + workflow_run = self._workflow_run_success( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, - outputs=self._task_state.latest_node_execution_info.workflow_node_execution.outputs - if self._task_state.latest_node_execution_info else None + outputs=outputs ) - self._task_state.workflow_run = workflow_run + self._task_state.workflow_run_id = workflow_run.id if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: outputs = workflow_run.outputs_dict self._task_state.answer = outputs.get('text', '') - def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun: - """ - Get workflow run. - :param workflow_run_id: workflow run id - :return: - """ - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() - return workflow_run + db.session.close() - def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution: - """ - Get workflow node execution. 
- :param workflow_node_execution_id: workflow node execution id - :return: - """ - workflow_node_execution = (db.session.query(WorkflowNodeExecution) - .filter(WorkflowNodeExecution.id == workflow_node_execution_id).first()) - return workflow_node_execution + return workflow_run def _save_message(self) -> None: """ @@ -567,7 +570,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): self._message.answer = self._task_state.answer self._message.provider_response_latency = time.perf_counter() - self._start_at - self._message.workflow_run_id = self._task_state.workflow_run.id + self._message.workflow_run_id = self._task_state.workflow_run_id if self._task_state.metadata and self._task_state.metadata.get('usage'): usage = LLMUsage(**self._task_state.metadata['usage']) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 9bd20f9785..8516feb87d 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -45,7 +45,7 @@ class TaskState(BaseModel): """ NodeExecutionInfo entity """ - workflow_node_execution: WorkflowNodeExecution + workflow_node_execution_id: str start_at: float class Config: @@ -57,7 +57,7 @@ class TaskState(BaseModel): answer: str = "" metadata: dict = {} - workflow_run: Optional[WorkflowRun] = None + workflow_run_id: Optional[str] = None start_at: Optional[float] = None total_tokens: int = 0 total_steps: int = 0 @@ -130,8 +130,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): self._on_node_finished(event) elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - self._on_workflow_finished(event) - workflow_run = self._task_state.workflow_run + workflow_run = self._on_workflow_finished(event) # response moderation if self._output_moderation_handler: @@ -179,8 +178,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(data) break elif isinstance(event, QueueWorkflowStartedEvent): - self._on_workflow_start() - workflow_run = self._task_state.workflow_run + workflow_run = self._on_workflow_start() response = { 'event': 'workflow_started', @@ -195,8 +193,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(response) elif isinstance(event, QueueNodeStartedEvent): - self._on_node_start(event) - workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution + workflow_node_execution = self._on_node_start(event) response = { 'event': 'node_started', @@ -214,8 +211,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(response) elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): - self._on_node_finished(event) - workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution + workflow_node_execution = self._on_node_finished(event) response = { 'event': 'node_finished', @@ -240,8 +236,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(response) elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - self._on_workflow_finished(event) - workflow_run = self._task_state.workflow_run + workflow_run = self._on_workflow_finished(event) # response moderation if 
self._output_moderation_handler: @@ -257,7 +252,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): replace_response = { 'event': 'text_replace', 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': self._task_state.workflow_run.id, + 'workflow_run_id': self._task_state.workflow_run_id, 'data': { 'text': self._task_state.answer } @@ -317,7 +312,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): response = { 'event': 'text_replace', 'task_id': self._application_generate_entity.task_id, - 'workflow_run_id': self._task_state.workflow_run.id, + 'workflow_run_id': self._task_state.workflow_run_id, 'data': { 'text': event.text } @@ -329,7 +324,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): else: continue - def _on_workflow_start(self) -> None: + def _on_workflow_start(self) -> WorkflowRun: self._task_state.start_at = time.perf_counter() workflow_run = self._init_workflow_run( @@ -344,11 +339,16 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): } ) - self._task_state.workflow_run = workflow_run + self._task_state.workflow_run_id = workflow_run.id - def _on_node_start(self, event: QueueNodeStartedEvent) -> None: + db.session.close() + + return workflow_run + + def _on_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() workflow_node_execution = self._init_node_execution_from_workflow_run( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, node_id=event.node_id, node_type=event.node_type, node_title=event.node_data.title, @@ -357,7 +357,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): ) latest_node_execution_info = TaskState.NodeExecutionInfo( - workflow_node_execution=workflow_node_execution, + workflow_node_execution_id=workflow_node_execution.id, start_at=time.perf_counter() ) @@ -366,11 +366,17 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): self._task_state.total_steps += 1 - def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None: + db.session.close() + + return workflow_node_execution + + def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution: current_node_execution = self._task_state.running_node_execution_infos[event.node_id] + workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first() if isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._workflow_node_execution_success( - workflow_node_execution=current_node_execution.workflow_node_execution, + workflow_node_execution=workflow_node_execution, start_at=current_node_execution.start_at, inputs=event.inputs, process_data=event.process_data, @@ -383,19 +389,24 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) else: workflow_node_execution = self._workflow_node_execution_failed( - workflow_node_execution=current_node_execution.workflow_node_execution, + workflow_node_execution=workflow_node_execution, start_at=current_node_execution.start_at, error=event.error ) # remove running node execution info del self._task_state.running_node_execution_infos[event.node_id] - 
self._task_state.latest_node_execution_info.workflow_node_execution = workflow_node_execution - def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None: + db.session.close() + + return workflow_node_execution + + def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \ + -> WorkflowRun: + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() if isinstance(event, QueueStopEvent): workflow_run = self._workflow_run_failed( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, @@ -404,7 +415,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): ) elif isinstance(event, QueueWorkflowFailedEvent): workflow_run = self._workflow_run_failed( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, @@ -412,39 +423,30 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): error=event.error ) else: + if self._task_state.latest_node_execution_info: + workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first() + outputs = workflow_node_execution.outputs + else: + outputs = None + workflow_run = self._workflow_run_success( - workflow_run=self._task_state.workflow_run, + workflow_run=workflow_run, start_at=self._task_state.start_at, total_tokens=self._task_state.total_tokens, total_steps=self._task_state.total_steps, - outputs=self._task_state.latest_node_execution_info.workflow_node_execution.outputs - if self._task_state.latest_node_execution_info else None + outputs=outputs ) - self._task_state.workflow_run = workflow_run + self._task_state.workflow_run_id = workflow_run.id if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: outputs = workflow_run.outputs_dict self._task_state.answer = outputs.get('text', '') - def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun: - """ - Get workflow run. - :param workflow_run_id: workflow run id - :return: - """ - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() - return workflow_run + db.session.close() - def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution: - """ - Get workflow node execution. 
-        :param workflow_node_execution_id: workflow node execution id
-        :return:
-        """
-        workflow_node_execution = (db.session.query(WorkflowNodeExecution)
-                                   .filter(WorkflowNodeExecution.id == workflow_node_execution_id).first())
-        return workflow_node_execution
+        return workflow_run
 
     def _save_workflow_app_log(self) -> None:
         """
@@ -461,7 +463,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
         """
         response = {
             'event': 'text_chunk',
-            'workflow_run_id': self._task_state.workflow_run.id,
+            'workflow_run_id': self._task_state.workflow_run_id,
             'task_id': self._application_generate_entity.task_id,
             'data': {
                 'text': text
diff --git a/api/core/app/apps/workflow_based_generate_task_pipeline.py b/api/core/app/apps/workflow_based_generate_task_pipeline.py
index d29cee3ac4..2b373d28e8 100644
--- a/api/core/app/apps/workflow_based_generate_task_pipeline.py
+++ b/api/core/app/apps/workflow_based_generate_task_pipeline.py
@@ -87,6 +87,7 @@ class WorkflowBasedGenerateTaskPipeline:
             workflow_run.finished_at = datetime.utcnow()
 
         db.session.commit()
+        db.session.refresh(workflow_run)
         db.session.close()
 
         return workflow_run
@@ -115,6 +116,7 @@ class WorkflowBasedGenerateTaskPipeline:
         workflow_run.finished_at = datetime.utcnow()
 
         db.session.commit()
+        db.session.refresh(workflow_run)
         db.session.close()
 
         return workflow_run
@@ -185,6 +187,7 @@ class WorkflowBasedGenerateTaskPipeline:
         workflow_node_execution.finished_at = datetime.utcnow()
 
         db.session.commit()
+        db.session.refresh(workflow_node_execution)
         db.session.close()
 
         return workflow_node_execution
@@ -205,6 +208,7 @@ class WorkflowBasedGenerateTaskPipeline:
         workflow_node_execution.finished_at = datetime.utcnow()
 
         db.session.commit()
+        db.session.refresh(workflow_node_execution)
         db.session.close()
 
         return workflow_node_execution

From 2d8497f79baebf0c52837eebf96077bb22df6d6d Mon Sep 17 00:00:00 2001
From: takatost
Date: Sun, 10 Mar 2024 17:11:39 +0800
Subject: [PATCH 143/200] add readme for db connection management in App Runner and Task Pipeline

---
 api/core/app/apps/README.md | 45 +++++++++++++++++++++++++++++++++++++
 1 file changed, 45 insertions(+)
 create mode 100644 api/core/app/apps/README.md

diff --git a/api/core/app/apps/README.md b/api/core/app/apps/README.md
new file mode 100644
index 0000000000..a59c424a15
--- /dev/null
+++ b/api/core/app/apps/README.md
@@ -0,0 +1,45 @@
+## Guidelines for Database Connection Management in App Runner and Task Pipeline
+
+Due to the presence of tasks in App Runner that require long execution times, such as LLM generation and external requests, Flask-SQLAlchemy's strategy for database connection pooling is to allocate one connection (transaction) per request. This approach keeps a connection occupied even during non-DB tasks, leading to the inability to acquire new connections during high-concurrency requests due to multiple long-running tasks.
+
+Therefore, the database operations in App Runner and Task Pipeline must ensure connections are closed immediately after use, and it's better to pass IDs rather than Model objects to avoid detached-instance errors.
+
+Examples:
+
+1. Creating a new record:
+
+   ```python
+   app = App(id=1)
+   db.session.add(app)
+   db.session.commit()
+   db.session.refresh(app)  # Retrieve table default values, like created_at, cached in the app object, won't affect after close
+
+   # Process related app logic
+
+   db.session.close()
+
+   return app.id
+   ```
+
+2. Fetching a record from the table:
+
+   ```python
+   app = db.session.query(App).filter(App.id == app_id).first()
+
+   created_at = app.created_at
+
+   db.session.close()
+   ```
+
+3. Updating a table field:
+
+   ```python
+   app = db.session.query(App).filter(App.id == app_id).first()
+
+   app.updated_at = datetime.utcnow()
+   db.session.commit()
+   db.session.close()
+
+   return app_id
+   ```
+

From 1e6feadc7ecc9987cec762befa1d9ccf7f2a9006 Mon Sep 17 00:00:00 2001
From: Yeuoly
Date: Sun, 10 Mar 2024 17:55:24 +0800
Subject: [PATCH 144/200] fix: code node does not work as expected

---
 api/core/helper/code_executor/code_executor.py | 14 +++++++-------
 .../helper/code_executor/python_transformer.py | 10 ++++------
 api/core/workflow/nodes/code/code_node.py | 10 +++++-----
 3 files changed, 16 insertions(+), 18 deletions(-)

diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py
index f1bc4fbdaf..fb0ad9642a 100644
--- a/api/core/helper/code_executor/code_executor.py
+++ b/api/core/helper/code_executor/code_executor.py
@@ -1,5 +1,5 @@
 from os import environ
-from typing import Literal
+from typing import Literal, Optional
 
 from httpx import post
 from pydantic import BaseModel
@@ -16,8 +16,8 @@ class CodeExecutionException(Exception):
 
 class CodeExecutionResponse(BaseModel):
     class Data(BaseModel):
-        stdout: str
-        stderr: str
+        stdout: Optional[str]
+        error: Optional[str]
 
     code: int
     message: str
@@ -58,9 +58,9 @@ class CodeExecutor:
                 raise Exception('Failed to execute code')
         except CodeExecutionException as e:
             raise e
-        except Exception:
+        except Exception as e:
             raise CodeExecutionException('Failed to execute code')
-
+
         try:
             response = response.json()
         except:
@@ -71,7 +71,7 @@ class CodeExecutor:
         if response.code != 0:
             raise CodeExecutionException(response.message)
 
-        if response.data.stderr:
-            raise CodeExecutionException(response.data.stderr)
+        if response.data.error:
+            raise CodeExecutionException(response.data.error)
 
         return template_transformer.transform_response(response.data.stdout)
\ No newline at end of file
diff --git a/api/core/helper/code_executor/python_transformer.py b/api/core/helper/code_executor/python_transformer.py
index 7b862649d8..27863ee443 100644
--- a/api/core/helper/code_executor/python_transformer.py
+++ b/api/core/helper/code_executor/python_transformer.py
@@ -11,11 +11,11 @@ PYTHON_RUNNER = """# declare main function here
 output = main(**{{inputs}})
 
 # convert output to json and print
-result = '''
-<<RESULT>>
+output = json.dumps(output, indent=4)
+
+result = f'''<<RESULT>>
 {output}
-<<RESULT>>
-'''
+<<RESULT>>'''
 print(result)
 """
 
@@ -47,11 +47,9 @@ class PythonTemplateTransformer(TemplateTransformer):
         :param response: response
         :return:
         """
-        # extract result
         result = re.search(r'<<RESULT>>(.*)<<RESULT>>', response, re.DOTALL)
         if not result:
             raise ValueError('Failed to parse result')
-
         result = result.group(1)
         return json.loads(result)
diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py
index 7d3162d983..9cc5865133 100644
--- a/api/core/workflow/nodes/code/code_node.py
+++ b/api/core/workflow/nodes/code/code_node.py
@@ -101,7 +101,6 @@ class CodeNode(BaseNode):
             )
 
             variables[variable] = value
-
         # Run code
         try:
             result = CodeExecutor.execute_code(
@@ -109,15 +108,16 @@ class CodeNode(BaseNode):
                 code=code,
                 inputs=variables
             )
-        except CodeExecutionException as e:
+
+            # Transform result
+            result = self._transform_result(result, node_data.outputs)
+        except (CodeExecutionException, ValueError) as e:
             return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
+                inputs=variables,
                 error=str(e)
             )
 
-        # Transform result
-        result = self._transform_result(result, node_data.outputs)
-
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
             inputs=variables,

From 751489fa547487bd521e4aa3a6bc297b577a2511 Mon Sep 17 00:00:00 2001
From: takatost
Date: Sun, 10 Mar 2024 18:01:55 +0800
Subject: [PATCH 145/200] modify readme

---
 api/core/app/apps/README.md | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/api/core/app/apps/README.md b/api/core/app/apps/README.md
index a59c424a15..856690dc57 100644
--- a/api/core/app/apps/README.md
+++ b/api/core/app/apps/README.md
@@ -14,7 +14,7 @@ Examples:
    db.session.commit()
    db.session.refresh(app)  # Retrieve table default values, like created_at, cached in the app object, won't affect after close
 
-   # Process related app logic
+   # Handle non-long-running tasks or store the content of the App instance in memory (via variable assignment).
 
    db.session.close()
 
@@ -29,6 +29,9 @@ Examples:
    created_at = app.created_at
 
    db.session.close()
+
+   # Handle tasks (including long-running).
+
    ```
 

From 80312620064d8f946ed7fbd449aa9f0f82c8a612 Mon Sep 17 00:00:00 2001
From: Yeuoly
Date: Sun, 10 Mar 2024 18:41:01 +0800
Subject: [PATCH 146/200] feat: workflow mock test

---
 .github/workflows/api-workflow-tests.yaml | 30 +++
 api/core/workflow/nodes/code/code_node.py | 10 +-
 api/tests/integration_tests/.env.example | 6 +-
 .../integration_tests/workflow/__init__.py | 0
 .../workflow/nodes/__mock/code_executor.py | 27 ++
 .../workflow/nodes/test_code.py | 244 ++++++++++++++++++
 6 files changed, 311 insertions(+), 6 deletions(-)
 create mode 100644 .github/workflows/api-workflow-tests.yaml
 create mode 100644 api/tests/integration_tests/workflow/__init__.py
 create mode 100644 api/tests/integration_tests/workflow/nodes/__mock/code_executor.py
 create mode 100644 api/tests/integration_tests/workflow/nodes/test_code.py

diff --git a/.github/workflows/api-workflow-tests.yaml b/.github/workflows/api-workflow-tests.yaml
new file mode 100644
index 0000000000..e4e35c6c44
--- /dev/null
+++ b/.github/workflows/api-workflow-tests.yaml
@@ -0,0 +1,30 @@
+name: Run Pytest
+
+on:
+  pull_request:
+    branches:
+      - main
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+
+    env:
+      MOCK_SWITCH: true
+
+    steps:
+    - name: Checkout code
+      uses: actions/checkout@v4
+
+    - name: Set up Python
+      uses: actions/setup-python@v5
+      with:
+        python-version: '3.10'
+        cache: 'pip'
+        cache-dependency-path: ./api/requirements.txt
+
+    - name: Install dependencies
+      run: pip install -r ./api/requirements.txt
+
+    - name: Run pytest
+      run: pytest api/tests/integration_tests/workflow
\ No newline at end of file
diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py
index 9cc5865133..8034f4e55d 100644
--- a/api/core/workflow/nodes/code/code_node.py
+++ b/api/core/workflow/nodes/code/code_node.py
@@ -132,10 +132,10 @@ class CodeNode(BaseNode):
         :return:
         """
         if not isinstance(value, str):
-            raise ValueError(f"{variable} in input form must be a string")
+            raise ValueError(f"{variable} in output form must be a string")
 
         if len(value) > MAX_STRING_LENGTH:
-            raise ValueError(f'{variable} in input form must be less than {MAX_STRING_LENGTH} characters')
+            raise ValueError(f'{variable} in output form must be less than {MAX_STRING_LENGTH} characters')
 
         return value.replace('\x00', '')
 
@@ -147,7 +147,7 @@ class CodeNode(BaseNode):
         :return:
         """
         if not isinstance(value, int | float):
-            raise ValueError(f"{variable} in input form must be a number")
+            raise ValueError(f"{variable} in output form must be a number")
 
         if value > MAX_NUMBER or value < MIN_NUMBER:
             raise ValueError(f'{variable} in input form is out of range.')
@@ -205,7 +205,7 @@ class CodeNode(BaseNode):
 
                     if len(result[output_name]) > MAX_NUMBER_ARRAY_LENGTH:
                         raise ValueError(
-                            f'{prefix}.{output_name} in input form must be less than {MAX_NUMBER_ARRAY_LENGTH} characters'
+                            f'{prefix}.{output_name} in output form must be less than {MAX_NUMBER_ARRAY_LENGTH} characters'
                         )
 
                     transformed_result[output_name] = [
@@ -224,7 +224,7 @@ class CodeNode(BaseNode):
 
                     if len(result[output_name]) > MAX_STRING_ARRAY_LENGTH:
                         raise ValueError(
-                            f'{prefix}.{output_name} in input form must be less than {MAX_STRING_ARRAY_LENGTH} characters'
+                            f'{prefix}.{output_name} in output form must be less than {MAX_STRING_ARRAY_LENGTH} characters'
                         )
 
                     transformed_result[output_name] = [
diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example
index 04abacf73d..dd1baa79d4 100644
--- a/api/tests/integration_tests/.env.example
+++ b/api/tests/integration_tests/.env.example
@@ -66,4 +66,8 @@ JINA_API_KEY=
 OLLAMA_BASE_URL=
 
 # Mock Switch
-MOCK_SWITCH=false
\ No newline at end of file
+MOCK_SWITCH=false
+
+# CODE EXECUTION CONFIGURATION
+CODE_EXECUTION_ENDPOINT=
+CODE_EXECUTION_API_KEY=
\ No newline at end of file
diff --git a/api/tests/integration_tests/workflow/__init__.py b/api/tests/integration_tests/workflow/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py
new file mode 100644
index 0000000000..b95c76b133
--- /dev/null
+++ b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py
@@ -0,0 +1,27 @@
+import os
+import pytest
+
+from typing import Literal
+from _pytest.monkeypatch import MonkeyPatch
+from core.helper.code_executor.code_executor import CodeExecutor
+
+MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'
+
+class MockedCodeExecutor:
+    @classmethod
+    def invoke(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict:
+        # invoke directly
+        if language == 'python3':
+            return {
+                "result": 3
+            }
+
+@pytest.fixture
+def setup_code_executor_mock(request, monkeypatch: MonkeyPatch):
+    if not MOCK:
+        yield
+        return
+
+    monkeypatch.setattr(CodeExecutor, "execute_code", MockedCodeExecutor.invoke)
+    yield
+    monkeypatch.undo()
diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py
new file mode 100644
index 0000000000..2885b9f458
--- /dev/null
+++ b/api/tests/integration_tests/workflow/nodes/test_code.py
@@ -0,0 +1,244 @@
+import pytest
+
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.nodes.code.code_node import CodeNode
+from models.workflow import WorkflowNodeExecutionStatus, WorkflowRunStatus
+from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
+
+@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True)
+def test_execute_code(setup_code_executor_mock):
+    code = '''
+    def main(args1: int, args2: int) -> dict:
+        return {
+            "result": args1 + args2,
+        }
+    '''
+    # trim first 4 spaces at the beginning of each line
+    code = '\n'.join([line[4:] for line in code.split('\n')])
+    node = CodeNode(config={
+        'id': '1',
+        'data': {
+            'outputs': {
'result': { + 'type': 'number', + }, + }, + 'title': '123', + 'variables': [ + { + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, + { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'] + } + ], + 'answer': '123', + 'code_language': 'python3', + 'code': code + } + }) + + # construct variable pool + pool = VariablePool(system_variables={}, user_inputs={}) + pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value=1) + pool.append_variable(node_id='1', variable_key_list=['123', 'args2'], value=2) + + # execute node + result = node.run(pool) + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs['result'] == 3 + assert result.error is None + +@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) +def test_execute_code_output_validator(setup_code_executor_mock): + code = ''' + def main(args1: int, args2: int) -> dict: + return { + "result": args1 + args2, + } + ''' + # trim first 4 spaces at the beginning of each line + code = '\n'.join([line[4:] for line in code.split('\n')]) + node = CodeNode(config={ + 'id': '1', + 'data': { + "outputs": { + "result": { + "type": "string", + }, + }, + 'title': '123', + 'variables': [ + { + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, + { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'] + } + ], + 'answer': '123', + 'code_language': 'python3', + 'code': code + } + }) + + # construct variable pool + pool = VariablePool(system_variables={}, user_inputs={}) + pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value=1) + pool.append_variable(node_id='1', variable_key_list=['123', 'args2'], value=2) + + # execute node + result = node.run(pool) + + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert result.error == 'result in output form must be a string' + +def test_execute_code_output_validator_depth(): + code = ''' + def main(args1: int, args2: int) -> dict: + return { + "result": { + "result": args1 + args2, + } + } + ''' + # trim first 4 spaces at the beginning of each line + code = '\n'.join([line[4:] for line in code.split('\n')]) + node = CodeNode(config={ + 'id': '1', + 'data': { + "outputs": { + "string_validator": { + "type": "string", + }, + "number_validator": { + "type": "number", + }, + "number_array_validator": { + "type": "array[number]", + }, + "string_array_validator": { + "type": "array[string]", + }, + "object_validator": { + "type": "object", + "children": { + "result": { + "type": "number", + }, + "depth": { + "type": "object", + "children": { + "depth": { + "type": "object", + "children": { + "depth": { + "type": "number", + } + } + } + } + } + } + }, + }, + 'title': '123', + 'variables': [ + { + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, + { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'] + } + ], + 'answer': '123', + 'code_language': 'python3', + 'code': code + } + }) + + # construct result + result = { + "number_validator": 1, + "string_validator": "1", + "number_array_validator": [1, 2, 3, 3.333], + "string_array_validator": ["1", "2", "3"], + "object_validator": { + "result": 1, + "depth": { + "depth": { + "depth": 1 + } + } + } + } + + # validate + node._transform_result(result, node.node_data.outputs) + + # construct result + result = { + "number_validator": "1", + "string_validator": 1, + "number_array_validator": ["1", "2", "3", "3.333"], + "string_array_validator": [1, 2, 3], + "object_validator": { + 
"result": "1", + "depth": { + "depth": { + "depth": "1" + } + } + } + } + + # validate + with pytest.raises(ValueError): + node._transform_result(result, node.node_data.outputs) + + # construct result + result = { + "number_validator": 1, + "string_validator": "1" * 2000, + "number_array_validator": [1, 2, 3, 3.333], + "string_array_validator": ["1", "2", "3"], + "object_validator": { + "result": 1, + "depth": { + "depth": { + "depth": 1 + } + } + } + } + + # validate + with pytest.raises(ValueError): + node._transform_result(result, node.node_data.outputs) + + # construct result + result = { + "number_validator": 1, + "string_validator": "1", + "number_array_validator": [1, 2, 3, 3.333] * 2000, + "string_array_validator": ["1", "2", "3"], + "object_validator": { + "result": 1, + "depth": { + "depth": { + "depth": 1 + } + } + } + } + + # validate + with pytest.raises(ValueError): + node._transform_result(result, node.node_data.outputs) + \ No newline at end of file From 9d0a832e403654ae73e0857eecffa4aedc077321 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Sun, 10 Mar 2024 18:41:49 +0800 Subject: [PATCH 147/200] refactor: github actions --- .github/workflows/{tool-tests.yaml => api-tools-tests.yaml} | 0 .github/workflows/api-workflow-tests.yaml | 1 + 2 files changed, 1 insertion(+) rename .github/workflows/{tool-tests.yaml => api-tools-tests.yaml} (100%) diff --git a/.github/workflows/tool-tests.yaml b/.github/workflows/api-tools-tests.yaml similarity index 100% rename from .github/workflows/tool-tests.yaml rename to .github/workflows/api-tools-tests.yaml diff --git a/.github/workflows/api-workflow-tests.yaml b/.github/workflows/api-workflow-tests.yaml index e4e35c6c44..37a138b44d 100644 --- a/.github/workflows/api-workflow-tests.yaml +++ b/.github/workflows/api-workflow-tests.yaml @@ -4,6 +4,7 @@ on: pull_request: branches: - main + - deploy/dev jobs: test: From be6836998320c2428d3a1a9003b1ad8688c3ecbd Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 10 Mar 2024 20:02:10 +0800 Subject: [PATCH 148/200] add workflow_app_log codes --- .../apps/workflow/generate_task_pipeline.py | 40 ++++++++++++++++--- api/models/workflow.py | 23 +++++++++++ 2 files changed, 58 insertions(+), 5 deletions(-) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 8516feb87d..7a244151f2 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -32,7 +32,15 @@ from core.workflow.entities.node_entities import NodeRunMetadataKey, SystemVaria from extensions.ext_database import db from models.account import Account from models.model import EndUser -from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun, WorkflowRunStatus, WorkflowRunTriggeredFrom +from models.workflow import ( + Workflow, + WorkflowAppLog, + WorkflowAppLogCreatedFrom, + WorkflowNodeExecution, + WorkflowRun, + WorkflowRunStatus, + WorkflowRunTriggeredFrom, +) logger = logging.getLogger(__name__) @@ -142,7 +150,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): ) # save workflow app log - self._save_workflow_app_log() + self._save_workflow_app_log(workflow_run) response = { 'task_id': self._application_generate_entity.task_id, @@ -261,7 +269,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(replace_response) # save workflow app log - self._save_workflow_app_log() + self._save_workflow_app_log(workflow_run) 
workflow_run_response = { 'event': 'workflow_finished', @@ -448,12 +456,34 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): return workflow_run - def _save_workflow_app_log(self) -> None: + def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: """ Save workflow app log. :return: """ - pass # todo + invoke_from = self._application_generate_entity.invoke_from + if invoke_from == InvokeFrom.SERVICE_API: + created_from = WorkflowAppLogCreatedFrom.SERVICE_API + elif invoke_from == InvokeFrom.EXPLORE: + created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP + elif invoke_from == InvokeFrom.WEB_APP: + created_from = WorkflowAppLogCreatedFrom.WEB_APP + else: + # not save log for debugging + return + + workflow_app_log = WorkflowAppLog( + tenant_id=workflow_run.tenant_id, + app_id=workflow_run.app_id, + workflow_id=workflow_run.workflow_id, + workflow_run_id=workflow_run.id, + created_from=created_from.value, + created_by_role=('account' if isinstance(self._user, Account) else 'end_user'), + created_by=self._user.id, + ) + db.session.add(workflow_app_log) + db.session.commit() + db.session.close() def _handle_chunk(self, text: str) -> dict: """ diff --git a/api/models/workflow.py b/api/models/workflow.py index 9768c364dd..5a3cdcf83c 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -433,6 +433,29 @@ class WorkflowNodeExecution(db.Model): def execution_metadata_dict(self): return self.execution_metadata if not self.execution_metadata else json.loads(self.execution_metadata) + +class WorkflowAppLogCreatedFrom(Enum): + """ + Workflow App Log Created From Enum + """ + SERVICE_API = 'service-api' + WEB_APP = 'web-app' + INSTALLED_APP = 'installed-app' + + @classmethod + def value_of(cls, value: str) -> 'WorkflowAppLogCreatedFrom': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid workflow app log created from value {value}') + + class WorkflowAppLog(db.Model): """ Workflow App execution log, excluding workflow debugging records. 
From a0a161886938d5d77521038daf7b4a58e07fd57b Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 10 Mar 2024 20:15:49 +0800 Subject: [PATCH 149/200] add tenant_id / app_id / workflow_id for nodes --- api/core/workflow/entities/workflow_entities.py | 14 +++++++++++--- api/core/workflow/nodes/base_node.py | 13 ++++++++++++- api/core/workflow/workflow_engine_manager.py | 17 ++++++++++++++--- 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 768ad6a130..91f9ef95fe 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -3,7 +3,7 @@ from typing import Optional from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode -from models.workflow import Workflow +from models.workflow import Workflow, WorkflowType class WorkflowNodeAndResult: @@ -16,7 +16,11 @@ class WorkflowNodeAndResult: class WorkflowRunState: - workflow: Workflow + tenant_id: str + app_id: str + workflow_id: str + workflow_type: WorkflowType + start_at: float variable_pool: VariablePool @@ -25,6 +29,10 @@ class WorkflowRunState: workflow_nodes_and_results: list[WorkflowNodeAndResult] = [] def __init__(self, workflow: Workflow, start_at: float, variable_pool: VariablePool): - self.workflow = workflow + self.workflow_id = workflow.id + self.tenant_id = workflow.tenant_id + self.app_id = workflow.app_id + self.workflow_type = WorkflowType.value_of(workflow.type) + self.start_at = start_at self.variable_pool = variable_pool diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 3f2e806433..6db25bea7e 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -12,14 +12,25 @@ class BaseNode(ABC): _node_data_cls: type[BaseNodeData] _node_type: NodeType + tenant_id: str + app_id: str + workflow_id: str + node_id: str node_data: BaseNodeData node_run_result: Optional[NodeRunResult] = None callbacks: list[BaseWorkflowCallback] - def __init__(self, config: dict, + def __init__(self, tenant_id: str, + app_id: str, + workflow_id: str, + config: dict, callbacks: list[BaseWorkflowCallback] = None) -> None: + self.tenant_id = tenant_id + self.app_id = app_id + self.workflow_id = workflow_id + self.node_id = config.get("id") if not self.node_id: raise ValueError("Node ID is required.") diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 50f79df1f0..d01746ceb8 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -122,6 +122,7 @@ class WorkflowEngineManager: while True: # get next node, multiple target nodes in the future next_node = self._get_next_node( + workflow_run_state=workflow_run_state, graph=graph, predecessor_node=predecessor_node, callbacks=callbacks @@ -198,7 +199,8 @@ class WorkflowEngineManager: error=error ) - def _get_next_node(self, graph: dict, + def _get_next_node(self, workflow_run_state: WorkflowRunState, + graph: dict, predecessor_node: Optional[BaseNode] = None, callbacks: list[BaseWorkflowCallback] = None) -> Optional[BaseNode]: """ @@ -216,7 +218,13 @@ class WorkflowEngineManager: if not predecessor_node: for node_config in nodes: if node_config.get('data', {}).get('type', '') == NodeType.START.value: - return StartNode(config=node_config) + 
return StartNode( + tenant_id=workflow_run_state.tenant_id, + app_id=workflow_run_state.app_id, + workflow_id=workflow_run_state.workflow_id, + config=node_config, + callbacks=callbacks + ) else: edges = graph.get('edges') source_node_id = predecessor_node.node_id @@ -256,6 +264,9 @@ class WorkflowEngineManager: target_node = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type'))) return target_node( + tenant_id=workflow_run_state.tenant_id, + app_id=workflow_run_state.app_id, + workflow_id=workflow_run_state.workflow_id, config=target_node_config, callbacks=callbacks ) @@ -354,7 +365,7 @@ class WorkflowEngineManager: :param node_run_result: node run result :return: """ - if workflow_run_state.workflow.type == WorkflowType.CHAT.value and node.node_type == NodeType.END: + if workflow_run_state.workflow_type == WorkflowType.CHAT and node.node_type == NodeType.END: workflow_nodes_and_result_before_end = workflow_run_state.workflow_nodes_and_results[-2] if workflow_nodes_and_result_before_end: if workflow_nodes_and_result_before_end.node.node_type == NodeType.LLM: From e0883302d26262e16456ceef14a79a4837b61cb5 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Sun, 10 Mar 2024 20:24:16 +0800 Subject: [PATCH 150/200] feat: jinja2 --- .../helper/code_executor/code_executor.py | 7 ++- .../helper/code_executor/jina2_transformer.py | 55 ++++++++++++++++++- .../template_transform_node.py | 6 +- .../workflow/nodes/__mock/code_executor.py | 2 +- 4 files changed, 64 insertions(+), 6 deletions(-) diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index fb0ad9642a..a62cf4de95 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -4,6 +4,7 @@ from typing import Literal, Optional from httpx import post from pydantic import BaseModel from yarl import URL +from core.helper.code_executor.jina2_transformer import Jinja2TemplateTransformer from core.helper.code_executor.python_transformer import PythonTemplateTransformer @@ -25,7 +26,7 @@ class CodeExecutionResponse(BaseModel): class CodeExecutor: @classmethod - def execute_code(cls, language: Literal['python3', 'javascript', 'jina2'], code: str, inputs: dict) -> dict: + def execute_code(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict: """ Execute code :param language: code language @@ -36,6 +37,8 @@ class CodeExecutor: template_transformer = None if language == 'python3': template_transformer = PythonTemplateTransformer + elif language == 'jinja2': + template_transformer = Jinja2TemplateTransformer else: raise CodeExecutionException('Unsupported language') @@ -46,7 +49,7 @@ class CodeExecutor: 'X-Api-Key': CODE_EXECUTION_API_KEY } data = { - 'language': language, + 'language': language if language != 'jinja2' else 'python3', 'code': runner, } diff --git a/api/core/helper/code_executor/jina2_transformer.py b/api/core/helper/code_executor/jina2_transformer.py index f87f5c14cb..87e8ce130f 100644 --- a/api/core/helper/code_executor/jina2_transformer.py +++ b/api/core/helper/code_executor/jina2_transformer.py @@ -1 +1,54 @@ -# TODO \ No newline at end of file +import json +import re + +from core.helper.code_executor.template_transformer import TemplateTransformer + +PYTHON_RUNNER = """ +import jinja2 + +template = jinja2.Template('''{{code}}''') + +def main(**inputs): + return template.render(**inputs) + +# execute main function, and return the result +output = main(**{{inputs}}) + 
+result = f'''<<RESULT>>{output}<<RESULT>>'''
+
+print(result)
+
+"""
+
+class Jinja2TemplateTransformer(TemplateTransformer):
+    @classmethod
+    def transform_caller(cls, code: str, inputs: dict) -> str:
+        """
+        Transform code to python runner
+        :param code: code
+        :param inputs: inputs
+        :return:
+        """
+
+        # transform jinja2 template to python code
+        runner = PYTHON_RUNNER.replace('{{code}}', code)
+        runner = runner.replace('{{inputs}}', json.dumps(inputs, indent=4))
+
+        return runner
+
+    @classmethod
+    def transform_response(cls, response: str) -> dict:
+        """
+        Transform response to dict
+        :param response: response
+        :return:
+        """
+        # extract result
+        result = re.search(r'<<RESULT>>(.*)<<RESULT>>', response, re.DOTALL)
+        if not result:
+            raise ValueError('Failed to parse result')
+        result = result.group(1)
+
+        return {
+            'result': result
+        }
\ No newline at end of file
diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py
index 724b84495c..a037332f4b 100644
--- a/api/core/workflow/nodes/template_transform/template_transform_node.py
+++ b/api/core/workflow/nodes/template_transform/template_transform_node.py
@@ -52,7 +52,7 @@ class TemplateTransformNode(BaseNode):
         # Run code
         try:
             result = CodeExecutor.execute_code(
-                language='jina2',
+                language='jinja2',
                 code=node_data.template,
                 inputs=variables
             )
@@ -66,7 +66,9 @@ class TemplateTransformNode(BaseNode):
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
             inputs=variables,
-            outputs=result['result']
+            outputs={
+                'output': result['result']
+            }
         )
 
     @classmethod
diff --git a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py
index b95c76b133..a1c8eb71dc 100644
--- a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py
+++ b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py
@@ -9,7 +9,7 @@ MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'
 
 class MockedCodeExecutor:
     @classmethod
-    def invoke(cls, language: Literal['python3', 'javascript', 'jina2'], code: str, inputs: dict) -> dict:
+    def invoke(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict:
         # invoke directly
         if language == 'python3':
             return {
From f8cba2679e4ce24667a8f365bf81631b71e5c156 Mon Sep 17 00:00:00 2001
From: Yeuoly
Date: Sun, 10 Mar 2024 21:12:07 +0800
Subject: [PATCH 151/200] fix: linter

---
 api/core/helper/code_executor/code_executor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py
index a62cf4de95..21a8ca5f9f 100644
--- a/api/core/helper/code_executor/code_executor.py
+++ b/api/core/helper/code_executor/code_executor.py
@@ -4,8 +4,8 @@ from typing import Literal, Optional
 from httpx import post
 from pydantic import BaseModel
 from yarl import URL
-from core.helper.code_executor.jina2_transformer import Jinja2TemplateTransformer
 
+from core.helper.code_executor.jina2_transformer import Jinja2TemplateTransformer
 from core.helper.code_executor.python_transformer import PythonTemplateTransformer
 
 # Code Executor
From 5e4bd9fc38ba406569099e9b49965a80ae5ef615 Mon Sep 17 00:00:00 2001
From: Yeuoly
Date: Mon, 11 Mar 2024 13:54:11 +0800
Subject: [PATCH 152/200] feat: tool node

---
 api/core/agent/base_agent_runner.py          |  69 ----------
 api/core/agent/cot_agent_runner.py           |   8 +-
 api/core/agent/fc_agent_runner.py            |   8 +-
api/core/tools/tool_manager.py | 114 ++++++++++------ api/core/tools/utils/message_transformer.py | 85 ++++++++++++ api/core/workflow/nodes/tool/entities.py | 23 ++++ api/core/workflow/nodes/tool/tool_node.py | 136 +++++++++++++++++++- 7 files changed, 334 insertions(+), 109 deletions(-) create mode 100644 api/core/tools/utils/message_transformer.py create mode 100644 api/core/workflow/nodes/tool/entities.py diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 0901b7e965..14602a7265 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -2,7 +2,6 @@ import json import logging import uuid from datetime import datetime -from mimetypes import guess_extension from typing import Optional, Union, cast from core.agent.entities import AgentEntity, AgentToolEntity @@ -39,7 +38,6 @@ from core.tools.entities.tool_entities import ( ) from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool from core.tools.tool.tool import Tool -from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_manager import ToolManager from extensions.ext_database import db from models.model import Message, MessageAgentThought, MessageFile @@ -462,73 +460,6 @@ class BaseAgentRunner(AppRunner): db.session.commit() db.session.close() - - def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]: - """ - Transform tool message into agent thought - """ - result = [] - - for message in messages: - if message.type == ToolInvokeMessage.MessageType.TEXT: - result.append(message) - elif message.type == ToolInvokeMessage.MessageType.LINK: - result.append(message) - elif message.type == ToolInvokeMessage.MessageType.IMAGE: - # try to download image - try: - file = ToolFileManager.create_file_by_url(user_id=self.user_id, tenant_id=self.tenant_id, - conversation_id=self.message.conversation_id, - file_url=message.message) - - url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}' - - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) - except Exception as e: - logger.exception(e) - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.TEXT, - message=f"Failed to download image: {message.message}, you can try to download it yourself.", - meta=message.meta.copy() if message.meta is not None else {}, - save_as=message.save_as, - )) - elif message.type == ToolInvokeMessage.MessageType.BLOB: - # get mime type and save blob to storage - mimetype = message.meta.get('mime_type', 'octet/stream') - # if message is str, encode it to bytes - if isinstance(message.message, str): - message.message = message.message.encode('utf-8') - file = ToolFileManager.create_file_by_raw(user_id=self.user_id, tenant_id=self.tenant_id, - conversation_id=self.message.conversation_id, - file_binary=message.message, - mimetype=mimetype) - - url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}' - - # check if file is image - if 'image' in mimetype: - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) - else: - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None 
else {}, - )) - else: - result.append(message) - - return result def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables): """ diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index cbb19aca53..0c5399f541 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -25,6 +25,7 @@ from core.tools.errors import ( ToolProviderCredentialValidationError, ToolProviderNotFoundError, ) +from core.tools.utils.message_transformer import ToolFileMessageTransformer from models.model import Conversation, Message @@ -280,7 +281,12 @@ class CotAgentRunner(BaseAgentRunner): tool_parameters=tool_call_args ) # transform tool response to llm friendly response - tool_response = self.transform_tool_invoke_messages(tool_response) + tool_response = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=tool_response, + user_id=self.user_id, + tenant_id=self.tenant_id, + conversation_id=self.message.conversation_id + ) # extract binary data from tool invoke message binary_files = self.extract_tool_response_binary(tool_response) # create message file diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 7c3849a12c..185d7684c8 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -23,6 +23,7 @@ from core.tools.errors import ( ToolProviderCredentialValidationError, ToolProviderNotFoundError, ) +from core.tools.utils.message_transformer import ToolFileMessageTransformer from models.model import Conversation, Message, MessageAgentThought logger = logging.getLogger(__name__) @@ -270,7 +271,12 @@ class FunctionCallAgentRunner(BaseAgentRunner): tool_parameters=tool_call_args, ) # transform tool invoke message to get LLM friendly message - tool_invoke_message = self.transform_tool_invoke_messages(tool_invoke_message) + tool_invoke_message = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=tool_invoke_message, + user_id=self.user_id, + tenant_id=self.tenant_id, + conversation_id=self.message.conversation_id + ) # extract binary data from tool invoke message binary_files = self.extract_tool_response_binary(tool_invoke_message) # create message file diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 24b2f287c1..ea66362195 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -34,6 +34,7 @@ from core.tools.utils.configuration import ( ToolParameterConfigurationManager, ) from core.tools.utils.encoder import serialize_base_model_dict +from core.workflow.nodes.tool.entities import ToolEntity from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider @@ -225,6 +226,48 @@ class ToolManager: else: raise ToolProviderNotFoundError(f'provider type {provider_type} not found') + @staticmethod + def _init_runtime_parameter(parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]: + """ + init runtime parameter + """ + parameter_value = parameters.get(parameter_rule.name) + if not parameter_value: + # get default value + parameter_value = parameter_rule.default + if not parameter_value and parameter_rule.required: + raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config") + + if parameter_rule.type == ToolParameter.ToolParameterType.SELECT: + # check if tool_parameter_config in options + options = list(map(lambda x: x.value, parameter_rule.options)) + if 
parameter_value not in options: + raise ValueError(f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}") + + # convert tool parameter config to correct type + try: + if parameter_rule.type == ToolParameter.ToolParameterType.NUMBER: + # check if tool parameter is integer + if isinstance(parameter_value, int): + parameter_value = parameter_value + elif isinstance(parameter_value, float): + parameter_value = parameter_value + elif isinstance(parameter_value, str): + if '.' in parameter_value: + parameter_value = float(parameter_value) + else: + parameter_value = int(parameter_value) + elif parameter_rule.type == ToolParameter.ToolParameterType.BOOLEAN: + parameter_value = bool(parameter_value) + elif parameter_rule.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]: + parameter_value = str(parameter_value) + elif parameter_rule.type == ToolParameter.ToolParameterType: + parameter_value = str(parameter_value) + except Exception as e: + raise ValueError(f"tool parameter {parameter_rule.name} value {parameter_value} is not correct type") + + return parameter_value + @staticmethod def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_callback: DifyAgentCallbackHandler) -> Tool: """ @@ -239,44 +282,9 @@ class ToolManager: parameters = tool_entity.get_all_runtime_parameters() for parameter in parameters: if parameter.form == ToolParameter.ToolParameterForm.FORM: - # get tool parameter from form - tool_parameter_config = agent_tool.tool_parameters.get(parameter.name) - if not tool_parameter_config: - # get default value - tool_parameter_config = parameter.default - if not tool_parameter_config and parameter.required: - raise ValueError(f"tool parameter {parameter.name} not found in tool config") - - if parameter.type == ToolParameter.ToolParameterType.SELECT: - # check if tool_parameter_config in options - options = list(map(lambda x: x.value, parameter.options)) - if tool_parameter_config not in options: - raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}") - - # convert tool parameter config to correct type - try: - if parameter.type == ToolParameter.ToolParameterType.NUMBER: - # check if tool parameter is integer - if isinstance(tool_parameter_config, int): - tool_parameter_config = tool_parameter_config - elif isinstance(tool_parameter_config, float): - tool_parameter_config = tool_parameter_config - elif isinstance(tool_parameter_config, str): - if '.' 
in tool_parameter_config: - tool_parameter_config = float(tool_parameter_config) - else: - tool_parameter_config = int(tool_parameter_config) - elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN: - tool_parameter_config = bool(tool_parameter_config) - elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]: - tool_parameter_config = str(tool_parameter_config) - elif parameter.type == ToolParameter.ToolParameterType: - tool_parameter_config = str(tool_parameter_config) - except Exception as e: - raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type") - # save tool parameter to tool entity memory - runtime_parameters[parameter.name] = tool_parameter_config + value = ToolManager._init_runtime_parameter(parameter, agent_tool.tool_parameters) + runtime_parameters[parameter.name] = value # decrypt runtime parameters encryption_manager = ToolParameterConfigurationManager( @@ -289,6 +297,38 @@ class ToolManager: tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity + + @staticmethod + def get_workflow_tool_runtime(tenant_id: str, workflow_tool: ToolEntity, agent_callback: DifyAgentCallbackHandler): + """ + get the workflow tool runtime + """ + tool_entity = ToolManager.get_tool_runtime( + provider_type=workflow_tool.provider_type, + provider_name=workflow_tool.provider_id, + tool_name=workflow_tool.tool_name, + tenant_id=tenant_id, + agent_callback=agent_callback + ) + runtime_parameters = {} + parameters = tool_entity.get_all_runtime_parameters() + + for parameter in parameters: + # save tool parameter to tool entity memory + value = ToolManager._init_runtime_parameter(parameter, workflow_tool.tool_parameters) + runtime_parameters[parameter.name] = value + + # decrypt runtime parameters + encryption_manager = ToolParameterConfigurationManager( + tenant_id=tenant_id, + tool_runtime=tool_entity, + provider_name=workflow_tool.provider_id, + provider_type=workflow_tool.provider_type, + ) + runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) + + tool_entity.runtime.runtime_parameters.update(runtime_parameters) + return tool_entity @staticmethod def get_builtin_provider_icon(provider: str) -> tuple[str, str]: diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py new file mode 100644 index 0000000000..3f456b4eb6 --- /dev/null +++ b/api/core/tools/utils/message_transformer.py @@ -0,0 +1,85 @@ +import logging +from mimetypes import guess_extension + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool_file_manager import ToolFileManager + +logger = logging.getLogger(__name__) + +class ToolFileMessageTransformer: + @staticmethod + def transform_tool_invoke_messages(messages: list[ToolInvokeMessage], + user_id: str, + tenant_id: str, + conversation_id: str) -> list[ToolInvokeMessage]: + """ + Transform tool message and handle file download + """ + result = [] + + for message in messages: + if message.type == ToolInvokeMessage.MessageType.TEXT: + result.append(message) + elif message.type == ToolInvokeMessage.MessageType.LINK: + result.append(message) + elif message.type == ToolInvokeMessage.MessageType.IMAGE: + # try to download image + try: + file = ToolFileManager.create_file_by_url( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=conversation_id, + file_url=message.message + ) + + url = 
f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}' + + result.append(ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + )) + except Exception as e: + logger.exception(e) + result.append(ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.TEXT, + message=f"Failed to download image: {message.message}, you can try to download it yourself.", + meta=message.meta.copy() if message.meta is not None else {}, + save_as=message.save_as, + )) + elif message.type == ToolInvokeMessage.MessageType.BLOB: + # get mime type and save blob to storage + mimetype = message.meta.get('mime_type', 'octet/stream') + # if message is str, encode it to bytes + if isinstance(message.message, str): + message.message = message.message.encode('utf-8') + + file = ToolFileManager.create_file_by_raw( + user_id=user_id, tenant_id=tenant_id, + conversation_id=conversation_id, + file_binary=message.message, + mimetype=mimetype + ) + + url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}' + + # check if file is image + if 'image' in mimetype: + result.append(ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + )) + else: + result.append(ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + )) + else: + result.append(message) + + return result \ No newline at end of file diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py new file mode 100644 index 0000000000..e782bd3004 --- /dev/null +++ b/api/core/workflow/nodes/tool/entities.py @@ -0,0 +1,23 @@ +from typing import Literal, Union + +from pydantic import BaseModel + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_entities import VariableSelector + +ToolParameterValue = Union[str, int, float, bool] + +class ToolEntity(BaseModel): + provider_id: str + provider_type: Literal['builtin', 'api'] + provider_name: str # redundancy + tool_name: str + tool_label: str # redundancy + tool_parameters: dict[str, ToolParameterValue] + + +class ToolNodeData(BaseNodeData, ToolEntity): + """ + Tool Node Schema + """ + tool_inputs: list[VariableSelector] diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index b805a53d2f..a0b0991eb6 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,5 +1,139 @@ +from os import path +from typing import cast + +from core.file.file_obj import FileTransferMethod +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool_manager import ToolManager +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.tool.entities import ToolNodeData +from models.workflow import WorkflowNodeExecutionStatus class ToolNode(BaseNode): - pass + """ + Tool Node + """ + _node_data_cls = ToolNodeData + _node_type = NodeType.TOOL + + def _run(self, 
variable_pool: VariablePool) -> NodeRunResult: + """ + Run the tool node + """ + + node_data = cast(ToolNodeData, self.node_data) + + # extract tool parameters + parameters = { + k.variable: variable_pool.get_variable_value(k.value_selector) + for k in node_data.tool_inputs + } + + if len(parameters) != len(node_data.tool_inputs): + raise ValueError('Invalid tool parameters') + + # get tool runtime + try: + tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data, None) + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=parameters, + error=f'Failed to get tool runtime: {str(e)}' + ) + + try: + messages = tool_runtime.invoke(None, parameters) + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=parameters, + error=f'Failed to invoke tool: {str(e)}' + ) + + # convert tool messages + plain_text, files = self._convert_tool_messages(messages) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCESS, + outputs={ + 'text': plain_text, + 'files': files + }, + ) + + def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[dict]]: + """ + Convert ToolInvokeMessages into tuple[plain_text, files] + """ + # transform message and handle file storage + messages = ToolFileMessageTransformer.transform_tool_invoke_messages(messages) + # extract plain text and files + files = self._extract_tool_response_binary(messages) + plain_text = self._extract_tool_response_text(messages) + + return plain_text, files + + def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[dict]: + """ + Extract tool response binary + """ + result = [] + + for response in tool_response: + if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ + response.type == ToolInvokeMessage.MessageType.IMAGE: + url = response.message + ext = path.splitext(url)[1] + mimetype = response.meta.get('mime_type', 'image/jpeg') + filename = response.save_as or url.split('/')[-1] + result.append({ + 'type': 'image', + 'transfer_method': FileTransferMethod.TOOL_FILE, + 'url': url, + 'upload_file_id': None, + 'filename': filename, + 'file-ext': ext, + 'mime-type': mimetype, + }) + elif response.type == ToolInvokeMessage.MessageType.BLOB: + result.append({ + 'type': 'image', # TODO: only support image for now + 'transfer_method': FileTransferMethod.TOOL_FILE, + 'url': response.message, + 'upload_file_id': None, + 'filename': response.save_as, + 'file-ext': path.splitext(response.save_as)[1], + 'mime-type': response.meta.get('mime_type', 'application/octet-stream'), + }) + elif response.type == ToolInvokeMessage.MessageType.LINK: + pass # TODO: + + return result + + def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> str: + """ + Extract tool response text + """ + return ''.join([ + f'{message.message}\n' if message.type == ToolInvokeMessage.MessageType.TEXT else + f'Link: {message.message}\n' if message.type == ToolInvokeMessage.MessageType.LINK else '' + for message in tool_response + ]) + + def _convert_tool_file(message: list[ToolInvokeMessage]) -> dict: + """ + Convert ToolInvokeMessage into file + """ + pass + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + """ + Extract variable selector to variable mapping + """ + pass \ No newline at end of file From 5eb7b4d56a93acbd9cbfbd52af44d84e0ab3d76a Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 
11 Mar 2024 16:13:52 +0800 Subject: [PATCH 153/200] feat: tool entity --- api/core/tools/tool_manager.py | 2 +- api/core/workflow/nodes/tool/entities.py | 19 +++++++++++---- api/core/workflow/nodes/tool/tool_node.py | 29 ++++++++++++----------- 3 files changed, 30 insertions(+), 20 deletions(-) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index ea66362195..52e1e71d82 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -315,7 +315,7 @@ class ToolManager: for parameter in parameters: # save tool parameter to tool entity memory - value = ToolManager._init_runtime_parameter(parameter, workflow_tool.tool_parameters) + value = ToolManager._init_runtime_parameter(parameter, workflow_tool.tool_configurations) runtime_parameters[parameter.name] = value # decrypt runtime parameters diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index e782bd3004..0b3bf76aac 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -1,6 +1,6 @@ -from typing import Literal, Union +from typing import Literal, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, validator from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector @@ -13,11 +13,20 @@ class ToolEntity(BaseModel): provider_name: str # redundancy tool_name: str tool_label: str # redundancy - tool_parameters: dict[str, ToolParameterValue] - + tool_configurations: dict[str, ToolParameterValue] class ToolNodeData(BaseNodeData, ToolEntity): + class ToolInput(VariableSelector): + variable_type: Literal['selector', 'static'] + value: Optional[str] + + @validator('value') + def check_value(cls, value, values, **kwargs): + if values['variable_type'] == 'static' and value is None: + raise ValueError('value is required for static variable') + return value + """ Tool Node Schema """ - tool_inputs: list[VariableSelector] + tool_parameters: list[ToolInput] diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index a0b0991eb6..f1897780f2 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -27,14 +27,8 @@ class ToolNode(BaseNode): node_data = cast(ToolNodeData, self.node_data) - # extract tool parameters - parameters = { - k.variable: variable_pool.get_variable_value(k.value_selector) - for k in node_data.tool_inputs - } - - if len(parameters) != len(node_data.tool_inputs): - raise ValueError('Invalid tool parameters') + # get parameters + parameters = self._generate_parameters(variable_pool, node_data) # get tool runtime try: @@ -47,6 +41,7 @@ class ToolNode(BaseNode): ) try: + # TODO: user_id messages = tool_runtime.invoke(None, parameters) except Exception as e: return NodeRunResult( @@ -59,12 +54,23 @@ class ToolNode(BaseNode): plain_text, files = self._convert_tool_messages(messages) return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCESS, + status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={ 'text': plain_text, 'files': files }, ) + + def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData) -> dict: + """ + Generate parameters + """ + return { + k.variable: + k.value if k.variable_type == 'static' else + variable_pool.get_variable_value(k.value) if k.variable_type == 'selector' else '' + for k in node_data.tool_parameters + } def _convert_tool_messages(self, 
messages: list[ToolInvokeMessage]) -> tuple[str, list[dict]]: """ @@ -125,11 +131,6 @@ class ToolNode(BaseNode): for message in tool_response ]) - def _convert_tool_file(message: list[ToolInvokeMessage]) -> dict: - """ - Convert ToolInvokeMessage into file - """ - pass @classmethod def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: From 7a6fa3655f648935a5e0b82d4458a1263a98734f Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 11 Mar 2024 16:31:43 +0800 Subject: [PATCH 154/200] add user for node --- api/core/app/apps/advanced_chat/app_runner.py | 6 +++++ api/core/app/apps/workflow/app_runner.py | 6 +++++ .../workflow/entities/workflow_entities.py | 12 +++++++-- api/core/workflow/nodes/base_node.py | 27 +++++++++++++++++++ api/core/workflow/workflow_engine_manager.py | 14 ++++++++-- .../unit_tests/core/workflow/__init__.py | 0 6 files changed, 61 insertions(+), 4 deletions(-) create mode 100644 api/tests/unit_tests/core/workflow/__init__.py diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index c42620b92f..5f5fd7010c 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -8,10 +8,12 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, + InvokeFrom, ) from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent from core.moderation.base import ModerationException from core.workflow.entities.node_entities import SystemVariable +from core.workflow.nodes.base_node import UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.model import App, Conversation, Message @@ -78,6 +80,10 @@ class AdvancedChatAppRunner(AppRunner): workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( workflow=workflow, + user_id=application_generate_entity.user_id, + user_from=UserFrom.ACCOUNT + if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + else UserFrom.END_USER, user_inputs=inputs, system_inputs={ SystemVariable.QUERY: query, diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 2d032fcdcb..922c3003bf 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -7,12 +7,14 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.app.entities.app_invoke_entities import ( AppGenerateEntity, + InvokeFrom, WorkflowAppGenerateEntity, ) from core.app.entities.queue_entities import QueueStopEvent, QueueTextChunkEvent from core.moderation.base import ModerationException from core.moderation.input_moderation import InputModeration from core.workflow.entities.node_entities import SystemVariable +from core.workflow.nodes.base_node import UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.model import App @@ -63,6 +65,10 @@ class WorkflowAppRunner: workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager.run_workflow( workflow=workflow, + user_id=application_generate_entity.user_id, + 
user_from=UserFrom.ACCOUNT + if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + else UserFrom.END_USER, user_inputs=inputs, system_inputs={ SystemVariable.FILES: files diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 91f9ef95fe..a78bf09a53 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -2,7 +2,7 @@ from typing import Optional from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.base_node import BaseNode, UserFrom from models.workflow import Workflow, WorkflowType @@ -20,6 +20,8 @@ class WorkflowRunState: app_id: str workflow_id: str workflow_type: WorkflowType + user_id: str + user_from: UserFrom start_at: float variable_pool: VariablePool @@ -28,11 +30,17 @@ class WorkflowRunState: workflow_nodes_and_results: list[WorkflowNodeAndResult] = [] - def __init__(self, workflow: Workflow, start_at: float, variable_pool: VariablePool): + def __init__(self, workflow: Workflow, + start_at: float, + variable_pool: VariablePool, + user_id: str, + user_from: UserFrom): self.workflow_id = workflow.id self.tenant_id = workflow.tenant_id self.app_id = workflow.app_id self.workflow_type = WorkflowType.value_of(workflow.type) + self.user_id = user_id + self.user_from = user_from self.start_at = start_at self.variable_pool = variable_pool diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 6db25bea7e..a603f484ef 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from enum import Enum from typing import Optional from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback @@ -8,6 +9,26 @@ from core.workflow.entities.variable_pool import VariablePool from models.workflow import WorkflowNodeExecutionStatus +class UserFrom(Enum): + """ + User from + """ + ACCOUNT = "account" + END_USER = "end-user" + + @classmethod + def value_of(cls, value: str) -> "UserFrom": + """ + Value of + :param value: value + :return: + """ + for item in cls: + if item.value == value: + return item + raise ValueError(f"Invalid value: {value}") + + class BaseNode(ABC): _node_data_cls: type[BaseNodeData] _node_type: NodeType @@ -15,6 +36,8 @@ class BaseNode(ABC): tenant_id: str app_id: str workflow_id: str + user_id: str + user_from: UserFrom node_id: str node_data: BaseNodeData @@ -25,11 +48,15 @@ class BaseNode(ABC): def __init__(self, tenant_id: str, app_id: str, workflow_id: str, + user_id: str, + user_from: UserFrom, config: dict, callbacks: list[BaseWorkflowCallback] = None) -> None: self.tenant_id = tenant_id self.app_id = app_id self.workflow_id = workflow_id + self.user_id = user_id + self.user_from = user_from self.node_id = config.get("id") if not self.node_id: diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index d01746ceb8..0bc13cbb5a 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -6,7 +6,7 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool, 
VariableValue from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState -from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.base_node import BaseNode, UserFrom from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode from core.workflow.nodes.end.end_node import EndNode @@ -76,12 +76,16 @@ class WorkflowEngineManager: return default_config def run_workflow(self, workflow: Workflow, + user_id: str, + user_from: UserFrom, user_inputs: dict, system_inputs: Optional[dict] = None, callbacks: list[BaseWorkflowCallback] = None) -> None: """ Run workflow :param workflow: Workflow instance + :param user_id: user id + :param user_from: user from :param user_inputs: user variables inputs :param system_inputs: system inputs, like: query, files :param callbacks: workflow callbacks @@ -113,7 +117,9 @@ class WorkflowEngineManager: variable_pool=VariablePool( system_variables=system_inputs, user_inputs=user_inputs - ) + ), + user_id=user_id, + user_from=user_from ) try: @@ -222,6 +228,8 @@ class WorkflowEngineManager: tenant_id=workflow_run_state.tenant_id, app_id=workflow_run_state.app_id, workflow_id=workflow_run_state.workflow_id, + user_id=workflow_run_state.user_id, + user_from=workflow_run_state.user_from, config=node_config, callbacks=callbacks ) @@ -267,6 +275,8 @@ class WorkflowEngineManager: tenant_id=workflow_run_state.tenant_id, app_id=workflow_run_state.app_id, workflow_id=workflow_run_state.workflow_id, + user_id=workflow_run_state.user_id, + user_from=workflow_run_state.user_from, config=target_node_config, callbacks=callbacks ) diff --git a/api/tests/unit_tests/core/workflow/__init__.py b/api/tests/unit_tests/core/workflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 From f911b1c488ccc18eaf274a6fa4c4869f57b6cf21 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 16:44:22 +0800 Subject: [PATCH 155/200] feat: support empty code output children --- api/core/workflow/nodes/code/code_node.py | 53 ++++- api/core/workflow/nodes/code/entities.py | 4 +- .../workflow/nodes/test_code.py | 206 ++++++++++-------- 3 files changed, 167 insertions(+), 96 deletions(-) diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 8034f4e55d..bfdec73199 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -153,11 +153,13 @@ class CodeNode(BaseNode): raise ValueError(f'{variable} in input form is out of range.') if isinstance(value, float): - value = round(value, MAX_PRECISION) + # raise error if precision is too high + if len(str(value).split('.')[1]) > MAX_PRECISION: + raise ValueError(f'{variable} in output form has too high precision.') return value - def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData.Output], + def _transform_result(self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = '', depth: int = 1) -> dict: """ @@ -170,6 +172,47 @@ class CodeNode(BaseNode): raise ValueError("Depth limit reached, object too deep.") transformed_result = {} + if output_schema is None: + # validate output thought instance type + for output_name, output_value in result.items(): + if isinstance(output_value, dict): + self._transform_result( + result=output_value, + output_schema=None, + prefix=f'{prefix}.{output_name}' if prefix else output_name, + depth=depth + 1 + ) + elif 
isinstance(output_value, (int, float)): + self._check_number( + value=output_value, + variable=f'{prefix}.{output_name}' if prefix else output_name + ) + elif isinstance(output_value, str): + self._check_string( + value=output_value, + variable=f'{prefix}.{output_name}' if prefix else output_name + ) + elif isinstance(output_value, list): + if all(isinstance(value, (int, float)) for value in output_value): + for value in output_value: + self._check_number( + value=value, + variable=f'{prefix}.{output_name}' if prefix else output_name + ) + elif all(isinstance(value, str) for value in output_value): + for value in output_value: + self._check_string( + value=value, + variable=f'{prefix}.{output_name}' if prefix else output_name + ) + else: + raise ValueError(f'Output {prefix}.{output_name} is not a valid array. make sure all elements are of the same type.') + else: + raise ValueError(f'Output {prefix}.{output_name} is not a valid type.') + + return result + + parameters_validated = {} for output_name, output_config in output_schema.items(): if output_config.type == 'object': # check if output is object @@ -236,6 +279,12 @@ class CodeNode(BaseNode): ] else: raise ValueError(f'Output type {output_config.type} is not supported.') + + parameters_validated[output_name] = True + + # check if all output parameters are validated + if len(parameters_validated) != len(result): + raise ValueError('Not all output parameters are validated.') return transformed_result diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index 6a18d181cb..ec3e3fe530 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -1,4 +1,4 @@ -from typing import Literal, Union +from typing import Literal, Optional from pydantic import BaseModel @@ -12,7 +12,7 @@ class CodeNodeData(BaseNodeData): """ class Output(BaseModel): type: Literal['string', 'number', 'object', 'array[string]', 'array[number]'] - children: Union[None, dict[str, 'Output']] + children: Optional[dict[str, 'Output']] variables: list[VariableSelector] answer: str diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 2885b9f458..0b7217b053 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -1,8 +1,9 @@ import pytest +from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.code.code_node import CodeNode -from models.workflow import WorkflowNodeExecutionStatus, WorkflowRunStatus +from models.workflow import WorkflowNodeExecutionStatus from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock @pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) @@ -15,30 +16,37 @@ def test_execute_code(setup_code_executor_mock): ''' # trim first 4 spaces at the beginning of each line code = '\n'.join([line[4:] for line in code.split('\n')]) - node = CodeNode(config={ - 'id': '1', - 'data': { - 'outputs': { - 'result': { - 'type': 'number', + node = CodeNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=InvokeFrom.WEB_APP, + config={ + 'id': '1', + 'data': { + 'outputs': { + 'result': { + 'type': 'number', + }, }, - }, - 'title': '123', - 'variables': [ - { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], - }, - { - 'variable': 
'args2', - 'value_selector': ['1', '123', 'args2'] - } - ], - 'answer': '123', - 'code_language': 'python3', - 'code': code + 'title': '123', + 'variables': [ + { + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, + { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'] + } + ], + 'answer': '123', + 'code_language': 'python3', + 'code': code + } } - }) + ) # construct variable pool pool = VariablePool(system_variables={}, user_inputs={}) @@ -61,30 +69,37 @@ def test_execute_code_output_validator(setup_code_executor_mock): ''' # trim first 4 spaces at the beginning of each line code = '\n'.join([line[4:] for line in code.split('\n')]) - node = CodeNode(config={ - 'id': '1', - 'data': { - "outputs": { - "result": { - "type": "string", + node = CodeNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=InvokeFrom.WEB_APP, + config={ + 'id': '1', + 'data': { + "outputs": { + "result": { + "type": "string", + }, }, - }, - 'title': '123', - 'variables': [ - { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], - }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } - ], - 'answer': '123', - 'code_language': 'python3', - 'code': code + 'title': '123', + 'variables': [ + { + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, + { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'] + } + ], + 'answer': '123', + 'code_language': 'python3', + 'code': code + } } - }) + ) # construct variable pool pool = VariablePool(system_variables={}, user_inputs={}) @@ -108,60 +123,67 @@ def test_execute_code_output_validator_depth(): ''' # trim first 4 spaces at the beginning of each line code = '\n'.join([line[4:] for line in code.split('\n')]) - node = CodeNode(config={ - 'id': '1', - 'data': { - "outputs": { - "string_validator": { - "type": "string", - }, - "number_validator": { - "type": "number", - }, - "number_array_validator": { - "type": "array[number]", - }, - "string_array_validator": { - "type": "array[string]", - }, - "object_validator": { - "type": "object", - "children": { - "result": { - "type": "number", - }, - "depth": { - "type": "object", - "children": { - "depth": { - "type": "object", - "children": { - "depth": { - "type": "number", + node = CodeNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=InvokeFrom.WEB_APP, + config={ + 'id': '1', + 'data': { + "outputs": { + "string_validator": { + "type": "string", + }, + "number_validator": { + "type": "number", + }, + "number_array_validator": { + "type": "array[number]", + }, + "string_array_validator": { + "type": "array[string]", + }, + "object_validator": { + "type": "object", + "children": { + "result": { + "type": "number", + }, + "depth": { + "type": "object", + "children": { + "depth": { + "type": "object", + "children": { + "depth": { + "type": "number", + } } } } } } + }, + }, + 'title': '123', + 'variables': [ + { + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, + { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'] } - }, - }, - 'title': '123', - 'variables': [ - { - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], - }, - { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'] - } - ], - 'answer': '123', - 'code_language': 'python3', - 'code': code + ], + 'answer': '123', + 'code_language': 'python3', + 'code': code + } } - }) + ) # construct result result = { From 91845fc9f6e652b1f6dd327abfc3870df373c295 Mon Sep 17 
00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 16:44:36 +0800 Subject: [PATCH 156/200] fix: linter --- api/core/workflow/nodes/code/code_node.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index bfdec73199..2f22a386e5 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -182,7 +182,7 @@ class CodeNode(BaseNode): prefix=f'{prefix}.{output_name}' if prefix else output_name, depth=depth + 1 ) - elif isinstance(output_value, (int, float)): + elif isinstance(output_value, int | float): self._check_number( value=output_value, variable=f'{prefix}.{output_name}' if prefix else output_name @@ -193,7 +193,7 @@ class CodeNode(BaseNode): variable=f'{prefix}.{output_name}' if prefix else output_name ) elif isinstance(output_value, list): - if all(isinstance(value, (int, float)) for value in output_value): + if all(isinstance(value, int | float) for value in output_value): for value in output_value: self._check_number( value=value, From 407bfb8182ee32c2057ae2081c2d8dbc895d5c01 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 16:46:11 +0800 Subject: [PATCH 157/200] feat: add user uid --- api/core/workflow/nodes/tool/tool_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index f1897780f2..b0bc1246bd 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -42,7 +42,7 @@ class ToolNode(BaseNode): try: # TODO: user_id - messages = tool_runtime.invoke(None, parameters) + messages = tool_runtime.invoke(self.user_id, parameters) except Exception as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, From f318fa058ccd95cb996c64663dbfcf4a1271e220 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 16:48:28 +0800 Subject: [PATCH 158/200] feat: add variable selector mapping --- api/core/workflow/nodes/tool/tool_node.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index b0bc1246bd..bfa7db3943 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -137,4 +137,8 @@ class ToolNode(BaseNode): """ Extract variable selector to variable mapping """ - pass \ No newline at end of file + return { + k.value_selector: k.variable + for k in cast(ToolNodeData, node_data).tool_parameters + if k.variable_type == 'selector' + } \ No newline at end of file From 88c29f613f8d01be1fcb01a0a1ba8bfee78cb6f7 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 16:51:27 +0800 Subject: [PATCH 159/200] fix: typing --- api/core/workflow/nodes/code/entities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index ec3e3fe530..0e2b3c99bf 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -12,7 +12,7 @@ class CodeNodeData(BaseNodeData): """ class Output(BaseModel): type: Literal['string', 'number', 'object', 'array[string]', 'array[number]'] - children: Optional[dict[str, 'Output']] + children: Optional[dict[str, 'CodeNodeData.Output']] variables: list[VariableSelector] answer: str From 33113034ea6ad02a8b59f5efe7645824ad6bedc3 Mon Sep 17 00:00:00 2001 From: takatost 
Date: Mon, 11 Mar 2024 18:49:22 +0800 Subject: [PATCH 160/200] add single step run --- api/controllers/console/__init__.py | 2 +- api/controllers/console/app/workflow.py | 23 +++-- api/core/workflow/errors.py | 10 +++ api/core/workflow/nodes/base_node.py | 4 +- api/core/workflow/nodes/code/code_node.py | 6 +- .../nodes/direct_answer/direct_answer_node.py | 10 ++- api/core/workflow/nodes/end/end_node.py | 2 +- .../nodes/http_request/http_request_node.py | 6 +- api/core/workflow/nodes/llm/llm_node.py | 2 +- api/core/workflow/nodes/start/start_node.py | 2 +- .../template_transform_node.py | 4 +- api/core/workflow/nodes/tool/tool_node.py | 6 +- api/core/workflow/workflow_engine_manager.py | 88 +++++++++++++++++++ api/fields/workflow_run_fields.py | 8 +- api/services/workflow_run_service.py | 14 +-- api/services/workflow_service.py | 86 +++++++++++++++++- 16 files changed, 233 insertions(+), 40 deletions(-) create mode 100644 api/core/workflow/errors.py diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index a6f803785a..853ca9e3a7 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -8,7 +8,7 @@ api = ExternalApi(bp) from . import admin, apikey, extension, feature, setup, version, ping # Import app controllers from .app import (advanced_prompt_template, annotation, app, audio, completion, conversation, generator, message, - model_config, site, statistic, workflow, workflow_app_log) + model_config, site, statistic, workflow, workflow_run, workflow_app_log) # Import auth controllers from .auth import activate, data_source_oauth, login, oauth # Import billing controllers diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 5f03a7cd37..6f81da5691 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -15,6 +15,7 @@ from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.app.entities.app_invoke_entities import InvokeFrom from fields.workflow_fields import workflow_fields +from fields.workflow_run_fields import workflow_run_node_execution_fields from libs.helper import TimestampField, uuid_value from libs.login import current_user, login_required from models.model import App, AppMode @@ -164,18 +165,24 @@ class DraftWorkflowNodeRunApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @marshal_with(workflow_run_node_execution_fields) def post(self, app_model: App, node_id: str): """ Run draft workflow node """ - # TODO - workflow_service = WorkflowService() - workflow_service.run_draft_workflow_node(app_model=app_model, node_id=node_id, account=current_user) + parser = reqparse.RequestParser() + parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') + args = parser.parse_args() - # TODO - return { - "result": "success" - } + workflow_service = WorkflowService() + workflow_node_execution = workflow_service.run_draft_workflow_node( + app_model=app_model, + node_id=node_id, + user_inputs=args.get('inputs'), + account=current_user + ) + + return workflow_node_execution class PublishedWorkflowApi(Resource): @@ -291,7 +298,7 @@ api.add_resource(DraftWorkflowApi, '/apps/<uuid:app_id>/workflows/draft') api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/run') api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run')
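As a reference for the endpoint defined above, a minimal client-side sketch; the host, app id, node id and token are hypothetical placeholders, and the payload mirrors the `inputs` argument the controller parses:

```python
# Hypothetical smoke test for DraftWorkflowNodeRunApi; all ids and the
# token below are placeholders, not values from this patch.
import httpx

APP_ID = "00000000-0000-0000-0000-000000000000"  # a draft app id
NODE_ID = "1711000000000"                        # a node id from the draft graph

resp = httpx.post(
    f"http://localhost:5001/console/api/apps/{APP_ID}/workflows/draft/nodes/{NODE_ID}/run",
    headers={"Authorization": "Bearer <console-access-token>"},
    json={"inputs": {"args1": "1+1"}},  # keys must match the node's variable mapping
)
resp.raise_for_status()
print(resp.json()["status"], resp.json().get("outputs"))
```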
api.add_resource(WorkflowTaskStopApi, '/apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop') -api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<uuid:node_id>/run') +api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run') api.add_resource(PublishedWorkflowApi, '/apps/<uuid:app_id>/workflows/published') api.add_resource(DefaultBlockConfigsApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs') api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>') diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py new file mode 100644 index 0000000000..fe79fadf66 --- /dev/null +++ b/api/core/workflow/errors.py @@ -0,0 +1,10 @@ +from core.workflow.entities.node_entities import NodeType + + +class WorkflowNodeRunFailedError(Exception): + def __init__(self, node_id: str, node_type: NodeType, node_title: str, error: str): + self.node_id = node_id + self.node_type = node_type + self.node_title = node_title + self.error = error + super().__init__(f"Node {node_title} run failed: {error}") diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index a603f484ef..dfba9d0385 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -108,7 +108,7 @@ class BaseNode(ABC): ) @classmethod - def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict: + def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict[str, list[str]]: """ Extract variable selector to variable mapping :param config: node config @@ -119,7 +119,7 @@ @classmethod @abstractmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: """ Extract variable selector to variable mapping :param node_data: node data diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 2f22a386e5..2c11e5ba00 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -289,7 +289,7 @@ class CodeNode(BaseNode): return transformed_result @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: CodeNodeData) -> dict[list[str], str]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: CodeNodeData) -> dict[str, list[str]]: """ Extract variable selector to variable mapping :param node_data: node data :return: """ return { - variable_selector.value_selector: variable_selector.variable for variable_selector in node_data.variables - } \ No newline at end of file + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables + } diff --git a/api/core/workflow/nodes/direct_answer/direct_answer_node.py b/api/core/workflow/nodes/direct_answer/direct_answer_node.py index 9193bab9ee..fedbc9b2d1 100644 --- a/api/core/workflow/nodes/direct_answer/direct_answer_node.py +++ b/api/core/workflow/nodes/direct_answer/direct_answer_node.py @@ -50,10 +50,16 @@ class DirectAnswerNode(BaseNode): ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: """ Extract variable selector to variable mapping :param node_data: node data :return: """ - return {} + node_data =
cast(cls._node_data_cls, node_data) + + variable_mapping = {} + for variable_selector in node_data.variables: + variable_mapping[variable_selector.variable] = variable_selector.value_selector + + return variable_mapping diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 65b0b86aa0..2666ccc4f9 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -56,7 +56,7 @@ class EndNode(BaseNode): ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: """ Extract variable selector to variable mapping :param node_data: node data diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index 4ee76deb83..853f8fe5e3 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -48,12 +48,12 @@ class HttpRequestNode(BaseNode): @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: HttpRequestNodeData) -> dict[list[str], str]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: HttpRequestNodeData) -> dict[str, list[str]]: """ Extract variable selector to variable mapping :param node_data: node data :return: """ return { - variable_selector.value_selector: variable_selector.variable for variable_selector in node_data.variables - } \ No newline at end of file + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables + } diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 90a7755b85..41e28937ac 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -23,7 +23,7 @@ class LLMNode(BaseNode): pass @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: """ Extract variable selector to variable mapping :param node_data: node data diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 2321e04bd4..08171457fb 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -69,7 +69,7 @@ class StartNode(BaseNode): return filtered_inputs @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: """ Extract variable selector to variable mapping :param node_data: node data diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index a037332f4b..c41f5d1030 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -72,12 +72,12 @@ class TemplateTransformNode(BaseNode): ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[list[str], str]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: 
TemplateTransformNodeData) -> dict[str, list[str]]: """ Extract variable selector to variable mapping :param node_data: node data :return: """ return { - variable_selector.value_selector: variable_selector.variable for variable_selector in node_data.variables + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables } \ No newline at end of file diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index bfa7db3943..69a97fc206 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -133,12 +133,12 @@ class ToolNode(BaseNode): @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: """ Extract variable selector to variable mapping """ return { - k.value_selector: k.variable + k.variable: k.value_selector for k in cast(ToolNodeData, node_data).tool_parameters if k.variable_type == 'selector' - } \ No newline at end of file + } diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 0bc13cbb5a..17225c19ea 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -6,6 +6,7 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState +from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.nodes.base_node import BaseNode, UserFrom from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode @@ -180,6 +181,93 @@ class WorkflowEngineManager: callbacks=callbacks ) + def single_step_run_workflow_node(self, workflow: Workflow, + node_id: str, + user_id: str, + user_inputs: dict) -> tuple[BaseNode, NodeRunResult]: + """ + Single step run workflow node + :param workflow: Workflow instance + :param node_id: node id + :param user_id: user id + :param user_inputs: user inputs + :return: + """ + # fetch node info from workflow graph + graph = workflow.graph_dict + if not graph: + raise ValueError('workflow graph not found') + + nodes = graph.get('nodes') + if not nodes: + raise ValueError('nodes not found in workflow graph') + + # fetch node config from node id + node_config = None + for node in nodes: + if node.get('id') == node_id: + node_config = node + break + + if not node_config: + raise ValueError('node id not found in workflow graph') + + # Get node class + node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type'))) + + # init workflow run state + node_instance = node_cls( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + user_id=user_id, + user_from=UserFrom.ACCOUNT, + config=node_config + ) + + try: + # init variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={} + ) + + # variable selector to variable mapping + try: + variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(node_config) + except NotImplementedError: + variable_mapping = {} + + for variable_key, variable_selector in 
variable_mapping.items(): + if variable_key not in user_inputs: + raise ValueError(f'Variable key {variable_key} not found in user inputs.') + + # fetch variable node id from variable selector + variable_node_id = variable_selector[0] + variable_key_list = variable_selector[1:] + + # append variable and value to variable pool + variable_pool.append_variable( + node_id=variable_node_id, + variable_key_list=variable_key_list, + value=user_inputs.get(variable_key) + ) + + # run node + node_run_result = node_instance.run( + variable_pool=variable_pool + ) + except Exception as e: + raise WorkflowNodeRunFailedError( + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_title=node_instance.node_data.title, + error=str(e) + ) + + return node_instance, node_run_result + + def _workflow_run_success(self, callbacks: list[BaseWorkflowCallback] = None) -> None: """ Workflow run success diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 572f472f1f..3135d91fd3 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -34,11 +34,9 @@ workflow_run_for_list_fields = { } workflow_run_pagination_fields = { - 'page': fields.Integer, - 'limit': fields.Integer(attribute='per_page'), - 'total': fields.Integer, - 'has_more': fields.Boolean(attribute='has_next'), - 'data': fields.List(fields.Nested(workflow_run_for_list_fields), attribute='items') + 'limit': fields.Integer(attribute='limit'), + 'has_more': fields.Boolean(attribute='has_more'), + 'data': fields.List(fields.Nested(workflow_run_for_list_fields), attribute='data') } workflow_run_detail_fields = { diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 70ce1f2ce0..1d3f93f224 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -34,26 +34,26 @@ class WorkflowRunService: if not last_workflow_run: raise ValueError('Last workflow run not exists') - conversations = base_query.filter( + workflow_runs = base_query.filter( WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id ).order_by(WorkflowRun.created_at.desc()).limit(limit).all() else: - conversations = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all() + workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all() has_more = False - if len(conversations) == limit: - current_page_first_conversation = conversations[-1] + if len(workflow_runs) == limit: + current_page_first_workflow_run = workflow_runs[-1] rest_count = base_query.filter( - WorkflowRun.created_at < current_page_first_conversation.created_at, - WorkflowRun.id != current_page_first_conversation.id + WorkflowRun.created_at < current_page_first_workflow_run.created_at, + WorkflowRun.id != current_page_first_workflow_run.id ).count() if rest_count > 0: has_more = True return InfiniteScrollPagination( - data=conversations, + data=workflow_runs, limit=limit, has_more=has_more ) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index f8bd80a0b1..2c9c07106c 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1,4 +1,5 @@ import json +import time from collections.abc import Generator from datetime import datetime from typing import Optional, Union @@ -9,12 +10,21 @@ from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from 
core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom +from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.entities.node_entities import NodeType +from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.account import Account from models.model import App, AppMode, EndUser -from models.workflow import Workflow, WorkflowType +from models.workflow import ( + CreatedByRole, + Workflow, + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, + WorkflowNodeExecutionTriggeredFrom, + WorkflowType, +) from services.workflow.workflow_converter import WorkflowConverter @@ -214,6 +224,80 @@ class WorkflowService: """ AppQueueManager.set_stop_flag(task_id, invoke_from, user.id) + def run_draft_workflow_node(self, app_model: App, + node_id: str, + user_inputs: dict, + account: Account) -> WorkflowNodeExecution: + """ + Run draft workflow node + """ + # fetch draft workflow by app_model + draft_workflow = self.get_draft_workflow(app_model=app_model) + if not draft_workflow: + raise ValueError('Workflow not initialized') + + # run draft workflow node + workflow_engine_manager = WorkflowEngineManager() + start_at = time.perf_counter() + + try: + node_instance, node_run_result = workflow_engine_manager.single_step_run_workflow_node( + workflow=draft_workflow, + node_id=node_id, + user_inputs=user_inputs, + user_id=account.id, + ) + except WorkflowNodeRunFailedError as e: + workflow_node_execution = WorkflowNodeExecution( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + workflow_id=draft_workflow.id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, + index=1, + node_id=e.node_id, + node_type=e.node_type.value, + title=e.node_title, + status=WorkflowNodeExecutionStatus.FAILED.value, + error=e.error, + elapsed_time=time.perf_counter() - start_at, + created_by_role=CreatedByRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.utcnow(), + finished_at=datetime.utcnow() + ) + db.session.add(workflow_node_execution) + db.session.commit() + + return workflow_node_execution + + # create workflow node execution + workflow_node_execution = WorkflowNodeExecution( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + workflow_id=draft_workflow.id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, + index=1, + node_id=node_id, + node_type=node_instance.node_type.value, + title=node_instance.node_data.title, + inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None, + process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None, + outputs=json.dumps(node_run_result.outputs) if node_run_result.outputs else None, + execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata)) + if node_run_result.metadata else None), + status=WorkflowNodeExecutionStatus.SUCCEEDED.value, + elapsed_time=time.perf_counter() - start_at, + created_by_role=CreatedByRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.utcnow(), + finished_at=datetime.utcnow() + ) + + db.session.add(workflow_node_execution) + db.session.commit() + + return workflow_node_execution + def convert_to_workflow(self, app_model: App, account: Account) -> App: """ Basic mode of chatbot app(expert mode) to workflow From f2bb0012fdc980c989c0805a27deedc35ad06388 Mon Sep 17 00:00:00 2001 From: takatost 
Date: Mon, 11 Mar 2024 18:52:24 +0800 Subject: [PATCH 161/200] add debug code --- api/core/workflow/nodes/direct_answer/direct_answer_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/direct_answer/direct_answer_node.py b/api/core/workflow/nodes/direct_answer/direct_answer_node.py index fedbc9b2d1..22ef2ed53b 100644 --- a/api/core/workflow/nodes/direct_answer/direct_answer_node.py +++ b/api/core/workflow/nodes/direct_answer/direct_answer_node.py @@ -39,7 +39,7 @@ class DirectAnswerNode(BaseNode): # publish answer as stream for word in answer: self.publish_text_chunk(word) - time.sleep(0.01) + time.sleep(10) # TODO for debug return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, From 7f7269d261349027dd93661b0d82c6f71ab5bef7 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 11 Mar 2024 19:04:48 +0800 Subject: [PATCH 162/200] remove unused params in workflow_run_for_list_fields --- api/fields/workflow_run_fields.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 3135d91fd3..72510cd27a 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -20,11 +20,7 @@ workflow_run_for_list_fields = { "id": fields.String, "sequence_number": fields.Integer, "version": fields.String, - "graph": fields.Raw(attribute='graph_dict'), - "inputs": fields.Raw(attribute='inputs_dict'), "status": fields.String, - "outputs": fields.Raw(attribute='outputs_dict'), - "error": fields.String, "elapsed_time": fields.Float, "total_tokens": fields.Integer, "total_steps": fields.Integer, From 7372776992ac2fd2e5976be1a5396ad6503e06ea Mon Sep 17 00:00:00 2001 From: jyong Date: Mon, 11 Mar 2024 20:06:38 +0800 Subject: [PATCH 163/200] knowledge node --- .../knowledge_retrieval/knowledge_retrieval_node.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index c6dd624921..7b8344418b 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -1,5 +1,13 @@ +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode class KnowledgeRetrievalNode(BaseNode): - pass + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + pass + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + pass From ebf9c41adb68008d88f61b896bdebdf84ae337f4 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 18:02:20 +0800 Subject: [PATCH 164/200] feat: http --- api/core/helper/ssrf_proxy.py | 1 + .../workflow/nodes/http_request/entities.py | 4 +- .../nodes/http_request/http_executor.py | 82 +++++++++++-------- .../nodes/http_request/http_request_node.py | 4 +- .../workflow/nodes/__mock/http.py | 82 +++++++++++++++++++ .../workflow/nodes/test_http.py | 51 ++++++++++++ 6 files changed, 188 insertions(+), 36 deletions(-) create mode 100644 api/tests/integration_tests/workflow/nodes/__mock/http.py create mode 100644 api/tests/integration_tests/workflow/nodes/test_http.py diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py 
index c44d4717e6..22f5fe57e0 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -26,6 +26,7 @@ httpx_proxies = { } if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None def get(url, *args, **kwargs): + print(url, kwargs) return _get(url=url, *args, proxies=httpx_proxies, **kwargs) def post(url, *args, **kwargs): diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index 1e906cbaa4..ce806b6bdb 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -1,4 +1,4 @@ -from typing import Literal, Union +from typing import Literal, Optional, Union from pydantic import BaseModel @@ -29,4 +29,4 @@ class HttpRequestNodeData(BaseNodeData): authorization: Authorization headers: str params: str - body: Body \ No newline at end of file + body: Optional[Body] \ No newline at end of file diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py index 82d879a89c..6134a7d780 100644 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -76,11 +76,17 @@ class HttpExecutor: # fill in params kv_paris = original_params.split('\n') for kv in kv_paris: + if not kv.strip(): + continue + kv = kv.split(':') - if len(kv) != 2: + if len(kv) == 2: + k, v = kv + elif len(kv) == 1: + k, v = kv[0], '' + else: raise ValueError(f'Invalid params {kv}') - k, v = kv self.params[k] = v # extract all template in headers @@ -96,51 +102,61 @@ class HttpExecutor: # fill in headers kv_paris = original_headers.split('\n') for kv in kv_paris: + if not kv.strip(): + continue + kv = kv.split(':') - if len(kv) != 2: + if len(kv) == 2: + k, v = kv + elif len(kv) == 1: + k, v = kv[0], '' + else: raise ValueError(f'Invalid headers {kv}') - k, v = kv self.headers[k] = v # extract all template in body - body_template = re.findall(r'{{(.*?)}}', node_data.body.data or '') or [] - body_template = list(set(body_template)) - original_body = node_data.body.data or '' - for body in body_template: - if not body: - continue + if node_data.body: + body_template = re.findall(r'{{(.*?)}}', node_data.body.data or '') or [] + body_template = list(set(body_template)) + original_body = node_data.body.data or '' + for body in body_template: + if not body: + continue - original_body = original_body.replace(f'{{{{{body}}}}}', str(variables.get(body, ''))) + original_body = original_body.replace(f'{{{{{body}}}}}', str(variables.get(body, ''))) - if node_data.body.type == 'json': - self.headers['Content-Type'] = 'application/json' - elif node_data.body.type == 'x-www-form-urlencoded': - self.headers['Content-Type'] = 'application/x-www-form-urlencoded' - # elif node_data.body.type == 'form-data': - # self.headers['Content-Type'] = 'multipart/form-data' + if node_data.body.type == 'json': + self.headers['Content-Type'] = 'application/json' + elif node_data.body.type == 'x-www-form-urlencoded': + self.headers['Content-Type'] = 'application/x-www-form-urlencoded' + # elif node_data.body.type == 'form-data': + # self.headers['Content-Type'] = 'multipart/form-data' - if node_data.body.type in ['form-data', 'x-www-form-urlencoded']: - body = {} - kv_paris = original_body.split('\n') - for kv in kv_paris: - kv = kv.split(':') - if len(kv) != 2: - raise ValueError(f'Invalid body {kv}') - body[kv[0]] = kv[1] + if node_data.body.type in ['form-data', 'x-www-form-urlencoded']: + body = {} 
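+                # parse the body template line by line: a 'k:v' line becomes a dict entry, a bare 'k' maps to an empty string, anything else is rejected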
+ kv_paris = original_body.split('\n') + for kv in kv_paris: + kv = kv.split(':') + if len(kv) == 2: + body[kv[0]] = kv[1] + elif len(kv) == 1: + body[kv[0]] = '' + else: + raise ValueError(f'Invalid body {kv}') - if node_data.body.type == 'form-data': - self.files = { - k: ('', v) for k, v in body.items() - } + if node_data.body.type == 'form-data': + self.files = { + k: ('', v) for k, v in body.items() + } + else: + self.body = urlencode(body) else: - self.body = urlencode(body) - else: - self.body = original_body + self.body = original_body def _assembling_headers(self) -> dict[str, Any]: authorization = deepcopy(self.authorization) - headers = deepcopy(self.headers) or [] + headers = deepcopy(self.headers) or {} if self.authorization.type == 'api-key': if self.authorization.config.api_key is None: raise ValueError('api_key is required') diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index 853f8fe5e3..1ef6f4b66d 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -24,10 +24,12 @@ class HttpRequestNode(BaseNode): # init http executor try: http_executor = HttpExecutor(node_data=node_data, variables=variables) - # invoke http executor + # invoke http executor response = http_executor.invoke() except Exception as e: + import traceback + print(traceback.format_exc()) return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, diff --git a/api/tests/integration_tests/workflow/nodes/__mock/http.py b/api/tests/integration_tests/workflow/nodes/__mock/http.py new file mode 100644 index 0000000000..3c2b0cebfc --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/__mock/http.py @@ -0,0 +1,82 @@ +import os +import pytest +import requests.api as requests +import httpx._api as httpx +from requests import Response as RequestsResponse +from yarl import URL + +from typing import Literal +from _pytest.monkeypatch import MonkeyPatch +from json import dumps + +MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' + +class MockedHttp: + def requests_request(self, method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'], + url: str, **kwargs) -> RequestsResponse: + """ + Mocked requests.request + """ + response = RequestsResponse() + response.url = str(URL(url) % kwargs.get('params', {})) + response.headers = kwargs.get('headers', {}) + + if url == 'http://404.com': + response.status_code = 404 + response._content = b'Not Found' + return response + + # get data, files + data = kwargs.get('data', None) + files = kwargs.get('files', None) + + if data is not None: + resp = dumps(data).encode('utf-8') + if files is not None: + resp = dumps(files).encode('utf-8') + else: + resp = b'OK' + + response.status_code = 200 + response._content = resp + return response + + def httpx_request(self, method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'], + url: str, **kwargs) -> httpx.Response: + """ + Mocked httpx.request + """ + response = httpx.Response() + response.url = str(URL(url) % kwargs.get('params', {})) + response.headers = kwargs.get('headers', {}) + + if url == 'http://404.com': + response.status_code = 404 + response.content = b'Not Found' + return response + + # get data, files + data = kwargs.get('data', None) + files = kwargs.get('files', None) + + if data is not None: + resp = dumps(data).encode('utf-8') + if files is not None: + resp = dumps(files).encode('utf-8') + else: + resp = b'OK' 
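+        # echo the JSON-encoded data/files back as the response body (b'OK' otherwise) so tests can assert on exactly what was sent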
+ + response.status_code = 200 + response.content = resp + return response + +@pytest.fixture +def setup_http_mock(request, monkeypatch: MonkeyPatch): + if not MOCK: + yield + return + + monkeypatch.setattr(requests, "request", MockedHttp.requests_request) + monkeypatch.setattr(httpx, "request", MockedHttp.httpx_request) + yield + monkeypatch.undo() \ No newline at end of file diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py new file mode 100644 index 0000000000..25c293d563 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -0,0 +1,51 @@ +from calendar import c +import pytest +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.http_request.entities import HttpRequestNodeData +from core.workflow.nodes.http_request.http_request_node import HttpRequestNode + +from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock + +BASIC_NODE_DATA = { + 'tenant_id': '1', + 'app_id': '1', + 'workflow_id': '1', + 'user_id': '1', + 'user_from': InvokeFrom.WEB_APP, +} + +# construct variable pool +pool = VariablePool(system_variables={}, user_inputs={}) +pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value=1) +pool.append_variable(node_id='1', variable_key_list=['123', 'args2'], value=2) + +@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +def test_get_param(setup_http_mock): + node = HttpRequestNode(config={ + 'id': '1', + 'data': { + 'title': 'http', + 'desc': '', + 'variables': [], + 'method': 'get', + 'url': 'http://example.com', + 'authorization': { + 'type': 'api-key', + 'config': { + 'type': 'basic', + 'api_key':'ak-xxx', + 'header': 'api-key', + } + }, + 'headers': '', + 'params': '', + 'body': None, + } + }, **BASIC_NODE_DATA) + + result = node.run(pool) + + print(result) + + assert 1==2 \ No newline at end of file From d3385a2715d8eeeee9d705cc0438283993d07aaa Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 19:51:31 +0800 Subject: [PATCH 165/200] feat --- api/core/helper/ssrf_proxy.py | 1 - .../nodes/http_request/http_executor.py | 19 +- .../nodes/http_request/http_request_node.py | 10 +- .../workflow/nodes/__mock/http.py | 15 +- .../workflow/nodes/test_http.py | 172 +++++++++++++++++- 5 files changed, 197 insertions(+), 20 deletions(-) diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 22f5fe57e0..c44d4717e6 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -26,7 +26,6 @@ httpx_proxies = { } if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None def get(url, *args, **kwargs): - print(url, kwargs) return _get(url=url, *args, proxies=httpx_proxies, **kwargs) def post(url, *args, **kwargs): diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py index 6134a7d780..c96d5f07d1 100644 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -43,6 +43,7 @@ class HttpExecutor: self.params = {} self.headers = {} self.body = None + self.files = None # init template self._init_template(node_data, variables) @@ -248,10 +249,24 @@ class HttpExecutor: server_url += f'?{urlencode(self.params)}' raw_request = f'{self.method.upper()} {server_url} HTTP/1.1\n' - for k, v in self.headers.items(): + + headers = self._assembling_headers() + for k, v 
in headers.items(): raw_request += f'{k}: {v}\n' raw_request += '\n' - raw_request += self.body or '' + + # if files, use multipart/form-data with boundary + if self.files: + boundary = '----WebKitFormBoundary7MA4YWxkTrZu0gW' + raw_request = f'--{boundary}\n' + raw_request + for k, v in self.files.items(): + raw_request += f'Content-Disposition: form-data; name="{k}"; filename="{v[0]}"\n' + raw_request += f'Content-Type: {v[1]}\n\n' + raw_request += v[1] + '\n' + raw_request += f'--{boundary}\n' + raw_request += '--\n' + else: + raw_request += self.body or '' return raw_request \ No newline at end of file diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index 1ef6f4b66d..c83e331fa8 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -28,13 +28,13 @@ class HttpRequestNode(BaseNode): # invoke http executor response = http_executor.invoke() except Exception as e: - import traceback - print(traceback.format_exc()) return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), - process_data=http_executor.to_raw_request() + process_data={ + 'request': http_executor.to_raw_request() + } ) return NodeRunResult( @@ -45,7 +45,9 @@ class HttpRequestNode(BaseNode): 'body': response, 'headers': response.headers }, - process_data=http_executor.to_raw_request() + process_data={ + 'request': http_executor.to_raw_request(), + } ) diff --git a/api/tests/integration_tests/workflow/nodes/__mock/http.py b/api/tests/integration_tests/workflow/nodes/__mock/http.py index 3c2b0cebfc..9cc43031f3 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/http.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/http.py @@ -3,6 +3,7 @@ import pytest import requests.api as requests import httpx._api as httpx from requests import Response as RequestsResponse +from httpx import Request as HttpxRequest from yarl import URL from typing import Literal @@ -12,8 +13,8 @@ from json import dumps MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' class MockedHttp: - def requests_request(self, method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'], - url: str, **kwargs) -> RequestsResponse: + def requests_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'], url: str, + **kwargs) -> RequestsResponse: """ Mocked requests.request """ @@ -41,13 +42,15 @@ class MockedHttp: response._content = resp return response - def httpx_request(self, method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'], + def httpx_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'], url: str, **kwargs) -> httpx.Response: """ Mocked httpx.request """ - response = httpx.Response() - response.url = str(URL(url) % kwargs.get('params', {})) + response = httpx.Response( + status_code=200, + request=HttpxRequest(method, url) + ) response.headers = kwargs.get('headers', {}) if url == 'http://404.com': @@ -67,7 +70,7 @@ class MockedHttp: resp = b'OK' response.status_code = 200 - response.content = resp + response._content = resp return response @pytest.fixture diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 25c293d563..6df8f6b673 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -2,7 +2,6 @@ from calendar import c import 
pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.http_request.entities import HttpRequestNodeData from core.workflow.nodes.http_request.http_request_node import HttpRequestNode from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock @@ -21,13 +20,16 @@ pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value=1) pool.append_variable(node_id='1', variable_key_list=['123', 'args2'], value=2) @pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) -def test_get_param(setup_http_mock): +def test_get(setup_http_mock): node = HttpRequestNode(config={ 'id': '1', 'data': { 'title': 'http', 'desc': '', - 'variables': [], + 'variables': [{ + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }], 'method': 'get', 'url': 'http://example.com', 'authorization': { @@ -38,14 +40,170 @@ def test_get_param(setup_http_mock): 'header': 'api-key', } }, - 'headers': '', - 'params': '', + 'headers': 'X-Header:123', + 'params': 'A:b', 'body': None, } }, **BASIC_NODE_DATA) result = node.run(pool) - print(result) + data = result.process_data.get('request', '') - assert 1==2 \ No newline at end of file + assert '?A=b' in data + assert 'api-key: Basic ak-xxx' in data + assert 'X-Header: 123' in data + +@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +def test_template(setup_http_mock): + node = HttpRequestNode(config={ + 'id': '1', + 'data': { + 'title': 'http', + 'desc': '', + 'variables': [{ + 'variable': 'args1', + 'value_selector': ['1', '123', 'args2'], + }], + 'method': 'get', + 'url': 'http://example.com/{{args1}}', + 'authorization': { + 'type': 'api-key', + 'config': { + 'type': 'basic', + 'api_key':'ak-xxx', + 'header': 'api-key', + } + }, + 'headers': 'X-Header:123\nX-Header2:{{args1}}', + 'params': 'A:b\nTemplate:{{args1}}', + 'body': None, + } + }, **BASIC_NODE_DATA) + + result = node.run(pool) + data = result.process_data.get('request', '') + + assert '?A=b' in data + assert 'Template=2' in data + assert 'api-key: Basic ak-xxx' in data + assert 'X-Header: 123' in data + assert 'X-Header2: 2' in data + +@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +def test_json(setup_http_mock): + node = HttpRequestNode(config={ + 'id': '1', + 'data': { + 'title': 'http', + 'desc': '', + 'variables': [{ + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }], + 'method': 'post', + 'url': 'http://example.com', + 'authorization': { + 'type': 'api-key', + 'config': { + 'type': 'basic', + 'api_key':'ak-xxx', + 'header': 'api-key', + } + }, + 'headers': 'X-Header:123', + 'params': 'A:b', + 'body': { + 'type': 'json', + 'data': '{"a": "{{args1}}"}' + }, + } + }, **BASIC_NODE_DATA) + + result = node.run(pool) + data = result.process_data.get('request', '') + + assert '{"a": "1"}' in data + assert 'api-key: Basic ak-xxx' in data + assert 'X-Header: 123' in data + +def test_x_www_form_urlencoded(setup_http_mock): + node = HttpRequestNode(config={ + 'id': '1', + 'data': { + 'title': 'http', + 'desc': '', + 'variables': [{ + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'], + }], + 'method': 'post', + 'url': 'http://example.com', + 'authorization': { + 'type': 'api-key', + 'config': { + 'type': 'basic', + 'api_key':'ak-xxx', + 'header': 'api-key', + } + }, + 'headers': 'X-Header:123', + 'params': 'A:b', + 
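+            # each 'key:{{var}}' line in the body template below is filled from the variable pool and then urlencoded, so the raw request should contain a=1&b=2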
'body': { + 'type': 'x-www-form-urlencoded', + 'data': 'a:{{args1}}\nb:{{args2}}' + }, + } + }, **BASIC_NODE_DATA) + + result = node.run(pool) + data = result.process_data.get('request', '') + + assert 'a=1&b=2' in data + assert 'api-key: Basic ak-xxx' in data + assert 'X-Header: 123' in data + +def test_form_data(setup_http_mock): + node = HttpRequestNode(config={ + 'id': '1', + 'data': { + 'title': 'http', + 'desc': '', + 'variables': [{ + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'], + }], + 'method': 'post', + 'url': 'http://example.com', + 'authorization': { + 'type': 'api-key', + 'config': { + 'type': 'basic', + 'api_key':'ak-xxx', + 'header': 'api-key', + } + }, + 'headers': 'X-Header:123', + 'params': 'A:b', + 'body': { + 'type': 'form-data', + 'data': 'a:{{args1}}\nb:{{args2}}' + }, + } + }, **BASIC_NODE_DATA) + + result = node.run(pool) + data = result.process_data.get('request', '') + + assert 'form-data; name="a"' in data + assert '1' in data + assert 'form-data; name="b"' in data + assert '2' in data + assert 'api-key: Basic ak-xxx' in data + assert 'X-Header: 123' in data From 513a8655b1009eec73f07c3f9390ab8ef2b60da7 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 21:31:39 +0800 Subject: [PATCH 166/200] test: tool --- api/core/tools/tool_manager.py | 9 ++- api/core/workflow/nodes/tool/tool_node.py | 11 +-- .../workflow/nodes/test_tool.py | 70 +++++++++++++++++++ 3 files changed, 83 insertions(+), 7 deletions(-) create mode 100644 api/tests/integration_tests/workflow/nodes/test_tool.py diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 52e1e71d82..600b54f1c2 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -315,8 +315,9 @@ class ToolManager: for parameter in parameters: # save tool parameter to tool entity memory - value = ToolManager._init_runtime_parameter(parameter, workflow_tool.tool_configurations) - runtime_parameters[parameter.name] = value + if parameter.form == ToolParameter.ToolParameterForm.FORM: + value = ToolManager._init_runtime_parameter(parameter, workflow_tool.tool_configurations) + runtime_parameters[parameter.name] = value # decrypt runtime parameters encryption_manager = ToolParameterConfigurationManager( @@ -325,7 +326,9 @@ class ToolManager: provider_name=workflow_tool.provider_id, provider_type=workflow_tool.provider_type, ) - runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) + + if runtime_parameters: + runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 69a97fc206..c62e025e75 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -29,7 +29,6 @@ class ToolNode(BaseNode): # get parameters parameters = self._generate_parameters(variable_pool, node_data) - # get tool runtime try: tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data, None) @@ -41,7 +40,6 @@ class ToolNode(BaseNode): ) try: - # TODO: user_id messages = tool_runtime.invoke(self.user_id, parameters) except Exception as e: return NodeRunResult( @@ -68,7 +66,7 @@ class ToolNode(BaseNode): return { k.variable: k.value if k.variable_type == 'static' else - variable_pool.get_variable_value(k.value) if 
k.variable_type == 'selector' else '' + variable_pool.get_variable_value(k.value_selector) if k.variable_type == 'selector' else '' for k in node_data.tool_parameters } @@ -77,7 +75,12 @@ class ToolNode(BaseNode): Convert ToolInvokeMessages into tuple[plain_text, files] """ # transform message and handle file storage - messages = ToolFileMessageTransformer.transform_tool_invoke_messages(messages) + messages = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=messages, + user_id=self.user_id, + tenant_id=self.tenant_id, + conversation_id='', + ) # extract plain text and files files = self._extract_tool_response_binary(messages) plain_text = self._extract_tool_response_text(messages) diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py new file mode 100644 index 0000000000..72e0d6f853 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -0,0 +1,70 @@ +import pytest +from core.app.entities.app_invoke_entities import InvokeFrom + +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.tool.tool_node import ToolNode +from models.workflow import WorkflowNodeExecutionStatus + +""" +class ToolEntity(BaseModel): + provider_id: str + provider_type: Literal['builtin', 'api'] + provider_name: str # redundancy + tool_name: str + tool_label: str # redundancy + tool_configurations: dict[str, ToolParameterValue] + +class ToolNodeData(BaseNodeData, ToolEntity): + class ToolInput(VariableSelector): + variable_type: Literal['selector', 'static'] + value: Optional[str] + + @validator('value') + def check_value(cls, value, values, **kwargs): + if values['variable_type'] == 'static' and value is None: + raise ValueError('value is required for static variable') + return value + + tool_parameters: list[ToolInput] + +""" + +def test_tool_invoke(): + pool = VariablePool(system_variables={}, user_inputs={}) + pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value='1+1') + + node = ToolNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=InvokeFrom.WEB_APP, + config={ + 'id': '1', + 'data': { + 'title': 'a', + 'desc': 'a', + 'provider_id': 'maths', + 'provider_type': 'builtin', + 'provider_name': 'maths', + 'tool_name': 'eval_expression', + 'tool_label': 'eval_expression', + 'tool_configurations': {}, + 'tool_parameters': [ + { + 'variable': 'expression', + 'value_selector': ['1', '123', 'args1'], + 'variable_type': 'selector', + 'value': None + }, + ] + } + } + ) + + # execute node + result = node.run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert '2' in result.outputs['text'] + assert result.outputs['files'] == [] \ No newline at end of file From 2c2b9e738929da9ab06689e37123c5d645b3be87 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 21:52:49 +0800 Subject: [PATCH 167/200] test: template transform --- .../template_transform_node.py | 9 +++- .../workflow/nodes/__mock/code_executor.py | 4 ++ .../workflow/nodes/test_template_transform.py | 46 +++++++++++++++++++ .../workflow/nodes/test_tool.py | 25 ---------- 4 files changed, 58 insertions(+), 26 deletions(-) create mode 100644 api/tests/integration_tests/workflow/nodes/test_template_transform.py diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index c41f5d1030..15d4b2a6e7 100644 --- 
a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -7,6 +7,7 @@ from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData from models.workflow import WorkflowNodeExecutionStatus +MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = 1000 class TemplateTransformNode(BaseNode): _node_data_cls = TemplateTransformNodeData @@ -48,7 +49,6 @@ class TemplateTransformNode(BaseNode): ) variables[variable] = value - # Run code try: result = CodeExecutor.execute_code( @@ -62,6 +62,13 @@ class TemplateTransformNode(BaseNode): status=WorkflowNodeExecutionStatus.FAILED, error=str(e) ) + + if len(result['result']) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: + return NodeRunResult( + inputs=variables, + status=WorkflowNodeExecutionStatus.FAILED, + error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters" + ) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, diff --git a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py index a1c8eb71dc..2eb987181f 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py @@ -15,6 +15,10 @@ class MockedCodeExecutor: return { "result": 3 } + elif language == 'jinja2': + return { + "result": "3" + } @pytest.fixture def setup_code_executor_mock(request, monkeypatch: MonkeyPatch): diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py new file mode 100644 index 0000000000..4348995a05 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -0,0 +1,46 @@ +import pytest +from core.app.entities.app_invoke_entities import InvokeFrom + +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from models.workflow import WorkflowNodeExecutionStatus +from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock + +@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) +def test_execute_code(setup_code_executor_mock): + code = '''{{args2}}''' + node = TemplateTransformNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=InvokeFrom.WEB_APP, + config={ + 'id': '1', + 'data': { + 'title': '123', + 'variables': [ + { + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }, + { + 'variable': 'args2', + 'value_selector': ['1', '123', 'args2'] + } + ], + 'template': code, + } + } + ) + + # construct variable pool + pool = VariablePool(system_variables={}, user_inputs={}) + pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value=1) + pool.append_variable(node_id='1', variable_key_list=['123', 'args2'], value=3) + + # execute node + result = node.run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs['output'] == '3' diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 72e0d6f853..66139563e2 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -1,34 +1,9 @@ 
-import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.tool.tool_node import ToolNode from models.workflow import WorkflowNodeExecutionStatus -""" -class ToolEntity(BaseModel): - provider_id: str - provider_type: Literal['builtin', 'api'] - provider_name: str # redundancy - tool_name: str - tool_label: str # redundancy - tool_configurations: dict[str, ToolParameterValue] - -class ToolNodeData(BaseNodeData, ToolEntity): - class ToolInput(VariableSelector): - variable_type: Literal['selector', 'static'] - value: Optional[str] - - @validator('value') - def check_value(cls, value, values, **kwargs): - if values['variable_type'] == 'static' and value is None: - raise ValueError('value is required for static variable') - return value - - tool_parameters: list[ToolInput] - -""" - def test_tool_invoke(): pool = VariablePool(system_variables={}, user_inputs={}) pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value='1+1') From b102562614d6c57db7b5b07efa4c352822b862f5 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 21:58:54 +0800 Subject: [PATCH 168/200] fix: forward-ref --- api/core/workflow/nodes/code/entities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index 0e2b3c99bf..ec3e3fe530 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -12,7 +12,7 @@ class CodeNodeData(BaseNodeData): """ class Output(BaseModel): type: Literal['string', 'number', 'object', 'array[string]', 'array[number]'] - children: Optional[dict[str, 'CodeNodeData.Output']] + children: Optional[dict[str, 'Output']] variables: list[VariableSelector] answer: str From a420953385f3ebd7fdd08996f5976395b5e8a99b Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 22:12:13 +0800 Subject: [PATCH 169/200] feat: docker-compose --- docker/docker-compose.middleware.yaml | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index afdabd078a..60604aeaec 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -11,6 +11,9 @@ services: POSTGRES_DB: dify # postgres data directory PGDATA: /var/lib/postgresql/data/pgdata + # The sandbox service endpoint. + CODE_EXECUTION_ENDPOINT: "http://sandbox:8194" + CODE_EXECUTION_API_KEY: dify-sandbox volumes: - ./volumes/db/data:/var/lib/postgresql/data ports: @@ -50,6 +53,16 @@ services: AUTHORIZATION_ADMINLIST_USERS: 'hello@dify.ai' ports: - "8080:8080" + + # The DifySandbox + sandbox: + image: langgenius/dify-sandbox:latest + restart: always + environment: + # The DifySandbox configurations + API_KEY: dify-sandbox + ports: + - "8194:8194" # Qdrant vector store. # uncomment to use qdrant as vector store. 
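With the sandbox service declared in the middleware stack, a quick reachability check can be scripted; the request path and payload below are assumptions inferred from how the code executor calls the sandbox, not a documented contract:

```python
# Hypothetical sandbox smoke test; the /v1/sandbox/run path and payload
# shape are assumptions based on the CodeExecutor client, not a stable API.
import httpx

resp = httpx.post(
    "http://localhost:8194/v1/sandbox/run",
    headers={"X-Api-Key": "dify-sandbox"},  # must match API_KEY in the compose file
    json={"language": "python3", "code": "print('ok')", "preload": ""},
)
print(resp.status_code, resp.text)
```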
From 951aaf5161d6c812745b6953ade0b22ff72cf630 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 22:14:28 +0800 Subject: [PATCH 170/200] feat: sandbox --- docker/docker-compose.middleware.yaml | 3 --- docker/docker-compose.yaml | 13 +++++++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 60604aeaec..8fba59c315 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -11,9 +11,6 @@ services: POSTGRES_DB: dify # postgres data directory PGDATA: /var/lib/postgresql/data/pgdata - # The sandbox service endpoint. - CODE_EXECUTION_ENDPOINT: "http://sandbox:8194" - CODE_EXECUTION_API_KEY: dify-sandbox volumes: - ./volumes/db/data:/var/lib/postgresql/data ports: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index d627bb3848..ca6b6cbf1a 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -122,6 +122,9 @@ services: SENTRY_TRACES_SAMPLE_RATE: 1.0 # The sample rate for Sentry profiles. Default: `1.0` SENTRY_PROFILES_SAMPLE_RATE: 1.0 + # The sandbox service endpoint. + CODE_EXECUTION_ENDPOINT: "http://sandbox:8194" + CODE_EXECUTION_API_KEY: dify-sandbox depends_on: - db - redis @@ -286,6 +289,16 @@ services: # ports: # - "8080:8080" + # The DifySandbox + sandbox: + image: langgenius/dify-sandbox:latest + restart: always + environment: + # The DifySandbox configurations + API_KEY: dify-sandbox + ports: + - "8194:8194" + # Qdrant vector store. # uncomment to use qdrant as vector store. # (if uncommented, you need to comment out the weaviate service above, From 92c1da8dbeb92310bb07c7507aee2420c4cd179e Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Tue, 12 Mar 2024 16:25:07 +0800 Subject: [PATCH 171/200] fix: remove answer --- api/core/workflow/nodes/code/entities.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index ec3e3fe530..d4d76c45f9 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -15,7 +15,6 @@ class CodeNodeData(BaseNodeData): children: Optional[dict[str, 'Output']] variables: list[VariableSelector] - answer: str code_language: Literal['python3', 'javascript'] code: str outputs: dict[str, Output] From e8751bebfa1b8b05ae6cf1274a4457075f51de07 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 12 Mar 2024 19:15:11 +0800 Subject: [PATCH 172/200] fix single step run error --- api/services/workflow_service.py | 64 +++++++++++++++++++++----------- 1 file changed, 42 insertions(+), 22 deletions(-) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 2c9c07106c..55f2526fbf 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -270,28 +270,48 @@ class WorkflowService: return workflow_node_execution - # create workflow node execution - workflow_node_execution = WorkflowNodeExecution( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - workflow_id=draft_workflow.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, - index=1, - node_id=node_id, - node_type=node_instance.node_type.value, - title=node_instance.node_data.title, - inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None, - process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None, - outputs=json.dumps(node_run_result.outputs) if node_run_result.outputs else 
None, - execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata)) - if node_run_result.metadata else None), - status=WorkflowNodeExecutionStatus.SUCCEEDED.value, - elapsed_time=time.perf_counter() - start_at, - created_by_role=CreatedByRole.ACCOUNT.value, - created_by=account.id, - created_at=datetime.utcnow(), - finished_at=datetime.utcnow() - ) + if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + # create workflow node execution + workflow_node_execution = WorkflowNodeExecution( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + workflow_id=draft_workflow.id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, + index=1, + node_id=node_id, + node_type=node_instance.node_type.value, + title=node_instance.node_data.title, + inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None, + process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None, + outputs=json.dumps(node_run_result.outputs) if node_run_result.outputs else None, + execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata)) + if node_run_result.metadata else None), + status=WorkflowNodeExecutionStatus.SUCCEEDED.value, + elapsed_time=time.perf_counter() - start_at, + created_by_role=CreatedByRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.utcnow(), + finished_at=datetime.utcnow() + ) + else: + # create workflow node execution + workflow_node_execution = WorkflowNodeExecution( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + workflow_id=draft_workflow.id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, + index=1, + node_id=node_id, + node_type=node_instance.node_type.value, + title=node_instance.node_data.title, + status=node_run_result.status.value, + error=node_run_result.error, + elapsed_time=time.perf_counter() - start_at, + created_by_role=CreatedByRole.ACCOUNT.value, + created_by=account.id, + created_at=datetime.utcnow(), + finished_at=datetime.utcnow() + ) db.session.add(workflow_node_execution) db.session.commit() From d88ac6c238412984e37967e51219e553f12bc254 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 12 Mar 2024 22:12:03 +0800 Subject: [PATCH 173/200] add llm node --- api/core/app/apps/base_app_runner.py | 31 +- .../easy_ui_based_generate_task_pipeline.py | 83 +--- api/core/model_manager.py | 4 +- api/core/prompt/advanced_prompt_transform.py | 51 ++- .../entities}/__init__.py | 0 .../entities/advanced_prompt_entities.py | 42 ++ api/core/prompt/prompt_transform.py | 19 +- api/core/prompt/simple_prompt_transform.py | 11 + api/core/prompt/utils/prompt_message_util.py | 85 ++++ api/core/workflow/entities/node_entities.py | 2 +- api/core/workflow/nodes/answer/__init__.py | 0 .../answer_node.py} | 8 +- .../{direct_answer => answer}/entities.py | 4 +- api/core/workflow/nodes/llm/entities.py | 45 ++- api/core/workflow/nodes/llm/llm_node.py | 370 +++++++++++++++++- api/core/workflow/workflow_engine_manager.py | 47 +-- .../prompt/test_advanced_prompt_transform.py | 77 ++-- 17 files changed, 697 insertions(+), 182 deletions(-) rename api/core/{workflow/nodes/direct_answer => prompt/entities}/__init__.py (100%) create mode 100644 api/core/prompt/entities/advanced_prompt_entities.py create mode 100644 api/core/prompt/utils/prompt_message_util.py create mode 100644 api/core/workflow/nodes/answer/__init__.py rename api/core/workflow/nodes/{direct_answer/direct_answer_node.py => answer/answer_node.py} (91%) rename 
api/core/workflow/nodes/{direct_answer => answer}/entities.py (75%) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index e7ce7f25ef..868e9e724f 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -23,7 +23,8 @@ from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.moderation.input_moderation import InputModeration from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.simple_prompt_transform import SimplePromptTransform +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform from models.model import App, AppMode, Message, MessageAnnotation @@ -155,13 +156,39 @@ class AppRunner: model_config=model_config ) else: + memory_config = MemoryConfig( + window=MemoryConfig.WindowConfig( + enabled=False + ) + ) + + model_mode = ModelMode.value_of(model_config.mode) + if model_mode == ModelMode.COMPLETION: + advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template + prompt_template = CompletionModelPromptTemplate( + text=advanced_completion_prompt_template.prompt + ) + + memory_config.role_prefix = MemoryConfig.RolePrefix( + user=advanced_completion_prompt_template.role_prefix.user, + assistant=advanced_completion_prompt_template.role_prefix.assistant + ) + else: + prompt_template = [] + for message in prompt_template_entity.advanced_chat_prompt_template.messages: + prompt_template.append(ChatModelMessage( + text=message.text, + role=message.role + )) + prompt_transform = AdvancedPromptTransform() prompt_messages = prompt_transform.get_prompt( - prompt_template_entity=prompt_template_entity, + prompt_template=prompt_template, inputs=inputs, query=query if query else '', files=files, context=context, + memory_config=memory_config, memory=memory, model_config=model_config ) diff --git a/api/core/app/apps/easy_ui_based_generate_task_pipeline.py b/api/core/app/apps/easy_ui_based_generate_task_pipeline.py index 856bfb623d..412029b024 100644 --- a/api/core/app/apps/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/apps/easy_ui_based_generate_task_pipeline.py @@ -30,17 +30,12 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentType, - PromptMessageRole, - TextPromptMessageContent, ) from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder from core.moderation.output_moderation import ModerationRule, OutputModeration -from core.prompt.simple_prompt_transform import ModelMode +from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.tools.tool_file_manager import ToolFileManager from events.message_event import message_was_created @@ -438,7 +433,10 @@ class EasyUIBasedGenerateTaskPipeline: self._message = 
db.session.query(Message).filter(Message.id == self._message.id).first() self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() - self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages) + self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( + self._model_config.mode, + self._task_state.llm_result.prompt_messages + ) self._message.message_tokens = usage.prompt_tokens self._message.message_unit_price = usage.prompt_unit_price self._message.message_price_unit = usage.prompt_price_unit @@ -582,77 +580,6 @@ class EasyUIBasedGenerateTaskPipeline: """ return "data: " + json.dumps(response) + "\n\n" - def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]: - """ - Prompt messages to prompt for saving. - :param prompt_messages: prompt messages - :return: - """ - prompts = [] - if self._model_config.mode == ModelMode.CHAT.value: - for prompt_message in prompt_messages: - if prompt_message.role == PromptMessageRole.USER: - role = 'user' - elif prompt_message.role == PromptMessageRole.ASSISTANT: - role = 'assistant' - elif prompt_message.role == PromptMessageRole.SYSTEM: - role = 'system' - else: - continue - - text = '' - files = [] - if isinstance(prompt_message.content, list): - for content in prompt_message.content: - if content.type == PromptMessageContentType.TEXT: - content = cast(TextPromptMessageContent, content) - text += content.data - else: - content = cast(ImagePromptMessageContent, content) - files.append({ - "type": 'image', - "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], - "detail": content.detail.value - }) - else: - text = prompt_message.content - - prompts.append({ - "role": role, - "text": text, - "files": files - }) - else: - prompt_message = prompt_messages[0] - text = '' - files = [] - if isinstance(prompt_message.content, list): - for content in prompt_message.content: - if content.type == PromptMessageContentType.TEXT: - content = cast(TextPromptMessageContent, content) - text += content.data - else: - content = cast(ImagePromptMessageContent, content) - files.append({ - "type": 'image', - "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], - "detail": content.detail.value - }) - else: - text = prompt_message.content - - params = { - "role": 'user', - "text": text, - } - - if files: - params['files'] = files - - prompts.append(params) - - return prompts - def _init_output_moderation(self) -> Optional[OutputModeration]: """ Init output moderation. 
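The hunk above only swaps the call site; the extracted helper itself is added as api/core/prompt/utils/prompt_message_util.py later in this patch. A sketch of the structure it returns for a chat-mode message, with illustrative values:

```python
from core.model_runtime.entities.message_entities import (
    SystemPromptMessage,
    UserPromptMessage,
)
from core.prompt.utils.prompt_message_util import PromptMessageUtil

# Chat mode maps each message role to 'system'/'user'/'assistant' and keeps
# image contents as truncated file entries.
prompts = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
    model_mode='chat',
    prompt_messages=[
        SystemPromptMessage(content='You are a helpful assistant.'),
        UserPromptMessage(content='Hi.'),
    ],
)
# -> [{'role': 'system', 'text': 'You are a helpful assistant.', 'files': []},
#     {'role': 'user', 'text': 'Hi.', 'files': []}]
```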
diff --git a/api/core/model_manager.py b/api/core/model_manager.py index aa16cf866f..8c06339927 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -24,11 +24,11 @@ class ModelInstance: """ def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None: - self._provider_model_bundle = provider_model_bundle + self.provider_model_bundle = provider_model_bundle self.model = model self.provider = provider_model_bundle.configuration.provider.provider self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model) - self.model_type_instance = self._provider_model_bundle.model_type_instance + self.model_type_instance = self.provider_model_bundle.model_type_instance def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict: """ diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 48b0d8ba02..60c77e943b 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,6 +1,5 @@ -from typing import Optional +from typing import Optional, Union -from core.app.app_config.entities import AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory @@ -12,6 +11,7 @@ from core.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.prompt_transform import PromptTransform from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_template_parser import PromptTemplateParser @@ -22,11 +22,12 @@ class AdvancedPromptTransform(PromptTransform): Advanced Prompt Transform for Workflow LLM Node. 
""" - def get_prompt(self, prompt_template_entity: PromptTemplateEntity, + def get_prompt(self, prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate], inputs: dict, query: str, files: list[FileObj], context: Optional[str], + memory_config: Optional[MemoryConfig], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: prompt_messages = [] @@ -34,21 +35,23 @@ class AdvancedPromptTransform(PromptTransform): model_mode = ModelMode.value_of(model_config.mode) if model_mode == ModelMode.COMPLETION: prompt_messages = self._get_completion_model_prompt_messages( - prompt_template_entity=prompt_template_entity, + prompt_template=prompt_template, inputs=inputs, query=query, files=files, context=context, + memory_config=memory_config, memory=memory, model_config=model_config ) elif model_mode == ModelMode.CHAT: prompt_messages = self._get_chat_model_prompt_messages( - prompt_template_entity=prompt_template_entity, + prompt_template=prompt_template, inputs=inputs, query=query, files=files, context=context, + memory_config=memory_config, memory=memory, model_config=model_config ) @@ -56,17 +59,18 @@ class AdvancedPromptTransform(PromptTransform): return prompt_messages def _get_completion_model_prompt_messages(self, - prompt_template_entity: PromptTemplateEntity, + prompt_template: CompletionModelPromptTemplate, inputs: dict, query: Optional[str], files: list[FileObj], context: Optional[str], + memory_config: Optional[MemoryConfig], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: """ Get completion model prompt messages. """ - raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt + raw_prompt = prompt_template.text prompt_messages = [] @@ -75,15 +79,17 @@ class AdvancedPromptTransform(PromptTransform): prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) - role_prefix = prompt_template_entity.advanced_completion_prompt_template.role_prefix - prompt_inputs = self._set_histories_variable( - memory=memory, - raw_prompt=raw_prompt, - role_prefix=role_prefix, - prompt_template=prompt_template, - prompt_inputs=prompt_inputs, - model_config=model_config - ) + if memory and memory_config: + role_prefix = memory_config.role_prefix + prompt_inputs = self._set_histories_variable( + memory=memory, + memory_config=memory_config, + raw_prompt=raw_prompt, + role_prefix=role_prefix, + prompt_template=prompt_template, + prompt_inputs=prompt_inputs, + model_config=model_config + ) if query: prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs) @@ -104,17 +110,18 @@ class AdvancedPromptTransform(PromptTransform): return prompt_messages def _get_chat_model_prompt_messages(self, - prompt_template_entity: PromptTemplateEntity, + prompt_template: list[ChatModelMessage], inputs: dict, query: Optional[str], files: list[FileObj], context: Optional[str], + memory_config: Optional[MemoryConfig], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: """ Get chat model prompt messages. 
""" - raw_prompt_list = prompt_template_entity.advanced_chat_prompt_template.messages + raw_prompt_list = prompt_template prompt_messages = [] @@ -137,8 +144,8 @@ class AdvancedPromptTransform(PromptTransform): elif prompt_item.role == PromptMessageRole.ASSISTANT: prompt_messages.append(AssistantPromptMessage(content=prompt)) - if memory: - prompt_messages = self._append_chat_histories(memory, prompt_messages, model_config) + if memory and memory_config: + prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config) if files: prompt_message_contents = [TextPromptMessageContent(data=query)] @@ -195,8 +202,9 @@ class AdvancedPromptTransform(PromptTransform): return prompt_inputs def _set_histories_variable(self, memory: TokenBufferMemory, + memory_config: MemoryConfig, raw_prompt: str, - role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity, + role_prefix: MemoryConfig.RolePrefix, prompt_template: PromptTemplateParser, prompt_inputs: dict, model_config: ModelConfigWithCredentialsEntity) -> dict: @@ -213,6 +221,7 @@ class AdvancedPromptTransform(PromptTransform): histories = self._get_history_messages_from_memory( memory=memory, + memory_config=memory_config, max_token_limit=rest_tokens, human_prefix=role_prefix.user, ai_prefix=role_prefix.assistant diff --git a/api/core/workflow/nodes/direct_answer/__init__.py b/api/core/prompt/entities/__init__.py similarity index 100% rename from api/core/workflow/nodes/direct_answer/__init__.py rename to api/core/prompt/entities/__init__.py diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py new file mode 100644 index 0000000000..97ac2e3e2a --- /dev/null +++ b/api/core/prompt/entities/advanced_prompt_entities.py @@ -0,0 +1,42 @@ +from typing import Optional + +from pydantic import BaseModel + +from core.model_runtime.entities.message_entities import PromptMessageRole + + +class ChatModelMessage(BaseModel): + """ + Chat Message. + """ + text: str + role: PromptMessageRole + + +class CompletionModelPromptTemplate(BaseModel): + """ + Completion Model Prompt Template. + """ + text: str + + +class MemoryConfig(BaseModel): + """ + Memory Config. + """ + class RolePrefix(BaseModel): + """ + Role Prefix. + """ + user: str + assistant: str + + class WindowConfig(BaseModel): + """ + Window Config. 
+ """ + enabled: bool + size: Optional[int] = None + + role_prefix: Optional[RolePrefix] = None + window: WindowConfig diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 02e91d9112..9bf2ae090f 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -5,19 +5,22 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.prompt.entities.advanced_prompt_entities import MemoryConfig class PromptTransform: def _append_chat_histories(self, memory: TokenBufferMemory, + memory_config: MemoryConfig, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: rest_tokens = self._calculate_rest_token(prompt_messages, model_config) - histories = self._get_history_messages_list_from_memory(memory, rest_tokens) + histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens) prompt_messages.extend(histories) return prompt_messages - def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity) -> int: + def _calculate_rest_token(self, prompt_messages: list[PromptMessage], + model_config: ModelConfigWithCredentialsEntity) -> int: rest_tokens = 2000 model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) @@ -44,6 +47,7 @@ class PromptTransform: return rest_tokens def _get_history_messages_from_memory(self, memory: TokenBufferMemory, + memory_config: MemoryConfig, max_token_limit: int, human_prefix: Optional[str] = None, ai_prefix: Optional[str] = None) -> str: @@ -58,13 +62,22 @@ class PromptTransform: if ai_prefix: kwargs['ai_prefix'] = ai_prefix + if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0: + kwargs['message_limit'] = memory_config.window.size + return memory.get_history_prompt_text( **kwargs ) def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory, + memory_config: MemoryConfig, max_token_limit: int) -> list[PromptMessage]: """Get memory messages.""" return memory.get_history_prompt_messages( - max_token_limit=max_token_limit + max_token_limit=max_token_limit, + message_limit=memory_config.window.size + if (memory_config.window.enabled + and memory_config.window.size is not None + and memory_config.window.size > 0) + else 10 ) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index ca0efb200c..613716c2cf 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -13,6 +13,7 @@ from core.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) +from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.prompt.prompt_transform import PromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import AppMode @@ -182,6 +183,11 @@ class SimplePromptTransform(PromptTransform): if memory: prompt_messages = self._append_chat_histories( memory=memory, + memory_config=MemoryConfig( + window=MemoryConfig.WindowConfig( + enabled=False, + ) + ), prompt_messages=prompt_messages, model_config=model_config ) @@ -220,6 +226,11 
@@ class SimplePromptTransform(PromptTransform): rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) histories = self._get_history_messages_from_memory( memory=memory, + memory_config=MemoryConfig( + window=MemoryConfig.WindowConfig( + enabled=False, + ) + ), max_token_limit=rest_tokens, ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py new file mode 100644 index 0000000000..5fceeb3595 --- /dev/null +++ b/api/core/prompt/utils/prompt_message_util.py @@ -0,0 +1,85 @@ +from typing import cast + +from core.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, + PromptMessageRole, + TextPromptMessageContent, +) +from core.prompt.simple_prompt_transform import ModelMode + + +class PromptMessageUtil: + @staticmethod + def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]: + """ + Prompt messages to prompt for saving. + :param model_mode: model mode + :param prompt_messages: prompt messages + :return: + """ + prompts = [] + if model_mode == ModelMode.CHAT.value: + for prompt_message in prompt_messages: + if prompt_message.role == PromptMessageRole.USER: + role = 'user' + elif prompt_message.role == PromptMessageRole.ASSISTANT: + role = 'assistant' + elif prompt_message.role == PromptMessageRole.SYSTEM: + role = 'system' + else: + continue + + text = '' + files = [] + if isinstance(prompt_message.content, list): + for content in prompt_message.content: + if content.type == PromptMessageContentType.TEXT: + content = cast(TextPromptMessageContent, content) + text += content.data + else: + content = cast(ImagePromptMessageContent, content) + files.append({ + "type": 'image', + "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], + "detail": content.detail.value + }) + else: + text = prompt_message.content + + prompts.append({ + "role": role, + "text": text, + "files": files + }) + else: + prompt_message = prompt_messages[0] + text = '' + files = [] + if isinstance(prompt_message.content, list): + for content in prompt_message.content: + if content.type == PromptMessageContentType.TEXT: + content = cast(TextPromptMessageContent, content) + text += content.data + else: + content = cast(ImagePromptMessageContent, content) + files.append({ + "type": 'image', + "data": content.data[:10] + '...[TRUNCATED]...' 
+ content.data[-10:], + "detail": content.detail.value + }) + else: + text = prompt_message.content + + params = { + "role": 'user', + "text": text, + } + + if files: + params['files'] = files + + prompts.append(params) + + return prompts diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 263172da31..befabfb3b4 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -12,7 +12,7 @@ class NodeType(Enum): """ START = 'start' END = 'end' - DIRECT_ANSWER = 'direct-answer' + ANSWER = 'answer' LLM = 'llm' KNOWLEDGE_RETRIEVAL = 'knowledge-retrieval' IF_ELSE = 'if-else' diff --git a/api/core/workflow/nodes/answer/__init__.py b/api/core/workflow/nodes/answer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/direct_answer/direct_answer_node.py b/api/core/workflow/nodes/answer/answer_node.py similarity index 91% rename from api/core/workflow/nodes/direct_answer/direct_answer_node.py rename to api/core/workflow/nodes/answer/answer_node.py index 22ef2ed53b..381ada1a1e 100644 --- a/api/core/workflow/nodes/direct_answer/direct_answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -5,14 +5,14 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import ValueType, VariablePool +from core.workflow.nodes.answer.entities import AnswerNodeData from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.direct_answer.entities import DirectAnswerNodeData from models.workflow import WorkflowNodeExecutionStatus -class DirectAnswerNode(BaseNode): - _node_data_cls = DirectAnswerNodeData - node_type = NodeType.DIRECT_ANSWER +class AnswerNode(BaseNode): + _node_data_cls = AnswerNodeData + node_type = NodeType.ANSWER def _run(self, variable_pool: VariablePool) -> NodeRunResult: """ diff --git a/api/core/workflow/nodes/direct_answer/entities.py b/api/core/workflow/nodes/answer/entities.py similarity index 75% rename from api/core/workflow/nodes/direct_answer/entities.py rename to api/core/workflow/nodes/answer/entities.py index e7c11e3c4d..7c6fed3e4e 100644 --- a/api/core/workflow/nodes/direct_answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -2,9 +2,9 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector -class DirectAnswerNodeData(BaseNodeData): +class AnswerNodeData(BaseNodeData): """ - DirectAnswer Node Data. + Answer Node Data. """ variables: list[VariableSelector] = [] answer: str diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index bd499543d9..67163c93cd 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -1,8 +1,51 @@ +from typing import Any, Literal, Optional, Union + +from pydantic import BaseModel + +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.variable_entities import VariableSelector + + +class ModelConfig(BaseModel): + """ + Model Config. 
+ """ + provider: str + name: str + mode: str + completion_params: dict[str, Any] = {} + + +class ContextConfig(BaseModel): + """ + Context Config. + """ + enabled: bool + variable_selector: Optional[list[str]] = None + + +class VisionConfig(BaseModel): + """ + Vision Config. + """ + class Configs(BaseModel): + """ + Configs. + """ + detail: Literal['low', 'high'] + + enabled: bool + configs: Optional[Configs] = None class LLMNodeData(BaseNodeData): """ LLM Node Data. """ - pass + model: ModelConfig + variables: list[VariableSelector] = [] + prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate] + memory: Optional[MemoryConfig] = None + context: ContextConfig + vision: VisionConfig diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 41e28937ac..d1050a5f5b 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -1,10 +1,27 @@ +from collections.abc import Generator from typing import Optional, cast +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.entities.model_entities import ModelStatus +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.file.file_obj import FileObj +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.message_entities import PromptMessage +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.utils.encoders import jsonable_encoder +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.llm.entities import LLMNodeData +from extensions.ext_database import db +from models.model import Conversation +from models.workflow import WorkflowNodeExecutionStatus class LLMNode(BaseNode): @@ -20,7 +37,341 @@ class LLMNode(BaseNode): node_data = self.node_data node_data = cast(self._node_data_cls, node_data) - pass + node_inputs = None + process_data = None + + try: + # fetch variables and fetch values from variable pool + inputs = self._fetch_inputs(node_data, variable_pool) + + node_inputs = { + **inputs + } + + # fetch files + files: list[FileObj] = self._fetch_files(node_data, variable_pool) + + if files: + node_inputs['#files#'] = [{ + 'type': file.type.value, + 'transfer_method': file.transfer_method.value, + 'url': file.url, + 'upload_file_id': file.upload_file_id, + } for file in files] + + # fetch context value + context = self._fetch_context(node_data, variable_pool) + + if context: + node_inputs['#context#'] = context + + # fetch model config + model_instance, model_config = self._fetch_model_config(node_data) + + # fetch memory + memory = self._fetch_memory(node_data, variable_pool, model_instance) + + # fetch prompt messages + prompt_messages, stop = self._fetch_prompt_messages( + 
node_data=node_data, + inputs=inputs, + files=files, + context=context, + memory=memory, + model_config=model_config + ) + + process_data = { + 'model_mode': model_config.mode, + 'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, + prompt_messages=prompt_messages + ) + } + + # handle invoke result + result_text, usage = self._invoke_llm( + node_data=node_data, + model_instance=model_instance, + prompt_messages=prompt_messages, + stop=stop + ) + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + inputs=node_inputs, + process_data=process_data + ) + + outputs = { + 'text': result_text, + 'usage': jsonable_encoder(usage) + } + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=node_inputs, + process_data=process_data, + outputs=outputs, + metadata={ + NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, + NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, + NodeRunMetadataKey.CURRENCY: usage.currency + } + ) + + def _invoke_llm(self, node_data: LLMNodeData, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + stop: list[str]) -> tuple[str, LLMUsage]: + """ + Invoke large language model + :param node_data: node data + :param model_instance: model instance + :param prompt_messages: prompt messages + :param stop: stop + :return: + """ + db.session.close() + + invoke_result = model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=node_data.model.completion_params, + stop=stop, + stream=True, + user=self.user_id, + ) + + # handle invoke result + return self._handle_invoke_result( + invoke_result=invoke_result + ) + + def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]: + """ + Handle invoke result + :param invoke_result: invoke result + :return: + """ + model = None + prompt_messages = [] + full_text = '' + usage = None + for result in invoke_result: + text = result.delta.message.content + full_text += text + + self.publish_text_chunk(text=text) + + if not model: + model = result.model + + if not prompt_messages: + prompt_messages = result.prompt_messages + + if not usage and result.delta.usage: + usage = result.delta.usage + + if not usage: + usage = LLMUsage.empty_usage() + + return full_text, usage + + def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]: + """ + Fetch inputs + :param node_data: node data + :param variable_pool: variable pool + :return: + """ + inputs = {} + for variable_selector in node_data.variables: + variable_value = variable_pool.get_variable_value(variable_selector.value_selector) + if variable_value is None: + raise ValueError(f'Variable {variable_selector.value_selector} not found') + + inputs[variable_selector.variable] = variable_value + + return inputs + + def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileObj]: + """ + Fetch files + :param node_data: node data + :param variable_pool: variable pool + :return: + """ + if not node_data.vision.enabled: + return [] + + files = variable_pool.get_variable_value(['sys', SystemVariable.FILES.value]) + if not files: + return [] + + return files + + def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]: + """ + Fetch context + :param node_data: node data + :param variable_pool: variable pool + :return: + """ + if not node_data.context.enabled: + return None + + context_value = 
variable_pool.get_variable_value(node_data.context.variable_selector) + if context_value: + if isinstance(context_value, str): + return context_value + elif isinstance(context_value, list): + context_str = '' + for item in context_value: + if 'content' not in item: + raise ValueError(f'Invalid context structure: {item}') + + context_str += item['content'] + '\n' + + return context_str.strip() + + return None + + def _fetch_model_config(self, node_data: LLMNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: + """ + Fetch model config + :param node_data: node data + :return: + """ + model_name = node_data.model.name + provider_name = node_data.model.provider + + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=self.tenant_id, + model_type=ModelType.LLM, + provider=provider_name, + model=model_name + ) + + provider_model_bundle = model_instance.provider_model_bundle + model_type_instance = model_instance.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + + model_credentials = model_instance.credentials + + # check model + provider_model = provider_model_bundle.configuration.get_provider_model( + model=model_name, + model_type=ModelType.LLM + ) + + if provider_model is None: + raise ValueError(f"Model {model_name} not exist.") + + if provider_model.status == ModelStatus.NO_CONFIGURE: + raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + elif provider_model.status == ModelStatus.NO_PERMISSION: + raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") + elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: + raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") + + # model config + completion_params = node_data.model.completion_params + stop = [] + if 'stop' in completion_params: + stop = completion_params['stop'] + del completion_params['stop'] + + # get model mode + model_mode = node_data.model.mode + if not model_mode: + raise ValueError("LLM mode is required.") + + model_schema = model_type_instance.get_model_schema( + model_name, + model_credentials + ) + + if not model_schema: + raise ValueError(f"Model {model_name} not exist.") + + return model_instance, ModelConfigWithCredentialsEntity( + provider=provider_name, + model=model_name, + model_schema=model_schema, + mode=model_mode, + provider_model_bundle=provider_model_bundle, + credentials=model_credentials, + parameters=completion_params, + stop=stop, + ) + + def _fetch_memory(self, node_data: LLMNodeData, + variable_pool: VariablePool, + model_instance: ModelInstance) -> Optional[TokenBufferMemory]: + """ + Fetch memory + :param node_data: node data + :param variable_pool: variable pool + :return: + """ + if not node_data.memory: + return None + + # get conversation id + conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION]) + if conversation_id is None: + return None + + # get conversation + conversation = db.session.query(Conversation).filter( + Conversation.tenant_id == self.tenant_id, + Conversation.app_id == self.app_id, + Conversation.id == conversation_id + ).first() + + if not conversation: + return None + + memory = TokenBufferMemory( + conversation=conversation, + model_instance=model_instance + ) + + return memory + + def _fetch_prompt_messages(self, node_data: LLMNodeData, + inputs: dict[str, str], + files: list[FileObj], + context: Optional[str], + memory: Optional[TokenBufferMemory], + 
model_config: ModelConfigWithCredentialsEntity) \ + -> tuple[list[PromptMessage], Optional[list[str]]]: + """ + Fetch prompt messages + :param node_data: node data + :param inputs: inputs + :param files: files + :param context: context + :param memory: memory + :param model_config: model config + :return: + """ + prompt_transform = AdvancedPromptTransform() + prompt_messages = prompt_transform.get_prompt( + prompt_template=node_data.prompt_template, + inputs=inputs, + query='', + files=files, + context=context, + memory_config=node_data.memory, + memory=memory, + model_config=model_config + ) + stop = model_config.stop + + return prompt_messages, stop @classmethod def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: @@ -29,9 +380,20 @@ class LLMNode(BaseNode): :param node_data: node data :return: """ - # TODO extract variable selector to variable mapping for single step debugging - return {} + node_data = node_data + node_data = cast(cls._node_data_cls, node_data) + variable_mapping = {} + for variable_selector in node_data.variables: + variable_mapping[variable_selector.variable] = variable_selector.value_selector + + if node_data.context.enabled: + variable_mapping['#context#'] = node_data.context.variable_selector + + if node_data.vision.enabled: + variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value] + + return variable_mapping @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 17225c19ea..49b9d4ac4d 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -7,9 +7,9 @@ from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResu from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState from core.workflow.errors import WorkflowNodeRunFailedError +from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.base_node import BaseNode, UserFrom from core.workflow.nodes.code.code_node import CodeNode -from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.http_request.http_request_node import HttpRequestNode from core.workflow.nodes.if_else.if_else_node import IfElseNode @@ -24,13 +24,12 @@ from extensions.ext_database import db from models.workflow import ( Workflow, WorkflowNodeExecutionStatus, - WorkflowType, ) node_classes = { NodeType.START: StartNode, NodeType.END: EndNode, - NodeType.DIRECT_ANSWER: DirectAnswerNode, + NodeType.ANSWER: AnswerNode, NodeType.LLM: LLMNode, NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, NodeType.IF_ELSE: IfElseNode, @@ -156,7 +155,7 @@ class WorkflowEngineManager: callbacks=callbacks ) - if next_node.node_type == NodeType.END: + if next_node.node_type in [NodeType.END, NodeType.ANSWER]: break predecessor_node = next_node @@ -402,10 +401,16 @@ class WorkflowEngineManager: # add to workflow_nodes_and_results workflow_run_state.workflow_nodes_and_results.append(workflow_nodes_and_result) - # run node, result must have inputs, process_data, outputs, execution_metadata - node_run_result = node.run( - variable_pool=workflow_run_state.variable_pool - ) + try: + # run node, result must have inputs, process_data, outputs, execution_metadata + 
node_run_result = node.run( + variable_pool=workflow_run_state.variable_pool + ) + except Exception as e: + node_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e) + ) if node_run_result.status == WorkflowNodeExecutionStatus.FAILED: # node run failed @@ -420,9 +425,6 @@ class WorkflowEngineManager: raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}") - # set end node output if in chat - self._set_end_node_output_if_in_chat(workflow_run_state, node, node_run_result) - workflow_nodes_and_result.result = node_run_result # node run success @@ -453,29 +455,6 @@ class WorkflowEngineManager: db.session.close() - def _set_end_node_output_if_in_chat(self, workflow_run_state: WorkflowRunState, - node: BaseNode, - node_run_result: NodeRunResult) -> None: - """ - Set end node output if in chat - :param workflow_run_state: workflow run state - :param node: current node - :param node_run_result: node run result - :return: - """ - if workflow_run_state.workflow_type == WorkflowType.CHAT and node.node_type == NodeType.END: - workflow_nodes_and_result_before_end = workflow_run_state.workflow_nodes_and_results[-2] - if workflow_nodes_and_result_before_end: - if workflow_nodes_and_result_before_end.node.node_type == NodeType.LLM: - if not node_run_result.outputs: - node_run_result.outputs = {} - - node_run_result.outputs['text'] = workflow_nodes_and_result_before_end.result.outputs.get('text') - elif workflow_nodes_and_result_before_end.node.node_type == NodeType.DIRECT_ANSWER: - if not node_run_result.outputs: - node_run_result.outputs = {} - - node_run_result.outputs['text'] = workflow_nodes_and_result_before_end.result.outputs.get('answer') def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 4357c6405c..5c08b9f168 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -2,12 +2,12 @@ from unittest.mock import MagicMock import pytest -from core.app.app_config.entities import PromptTemplateEntity, AdvancedCompletionPromptTemplateEntity, \ - ModelConfigEntity, AdvancedChatPromptTemplateEntity, AdvancedChatMessageEntity, FileUploadEntity +from core.app.app_config.entities import ModelConfigEntity, FileUploadEntity from core.file.file_obj import FileObj, FileType, FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig, ChatModelMessage from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import Conversation @@ -18,16 +18,20 @@ def test__get_completion_model_prompt_messages(): model_config_mock.model = 'gpt-3.5-turbo-instruct' prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}." 
- prompt_template_entity = PromptTemplateEntity( - prompt_type=PromptTemplateEntity.PromptType.ADVANCED, - advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( - prompt=prompt_template, - role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity( - user="Human", - assistant="Assistant" - ) + prompt_template_config = CompletionModelPromptTemplate( + text=prompt_template + ) + + memory_config = MemoryConfig( + role_prefix=MemoryConfig.RolePrefix( + user="Human", + assistant="Assistant" + ), + window=MemoryConfig.WindowConfig( + enabled=False ) ) + inputs = { "name": "John" } @@ -48,11 +52,12 @@ def test__get_completion_model_prompt_messages(): prompt_transform = AdvancedPromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) prompt_messages = prompt_transform._get_completion_model_prompt_messages( - prompt_template_entity=prompt_template_entity, + prompt_template=prompt_template_config, inputs=inputs, query=None, files=files, context=context, + memory_config=memory_config, memory=memory, model_config=model_config_mock ) @@ -67,7 +72,7 @@ def test__get_completion_model_prompt_messages(): def test__get_chat_model_prompt_messages(get_chat_model_args): - model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args + model_config_mock, memory_config, messages, inputs, context = get_chat_model_args files = [] query = "Hi2." @@ -86,11 +91,12 @@ def test__get_chat_model_prompt_messages(get_chat_model_args): prompt_transform = AdvancedPromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) prompt_messages = prompt_transform._get_chat_model_prompt_messages( - prompt_template_entity=prompt_template_entity, + prompt_template=messages, inputs=inputs, query=query, files=files, context=context, + memory_config=memory_config, memory=memory, model_config=model_config_mock ) @@ -98,24 +104,25 @@ def test__get_chat_model_prompt_messages(get_chat_model_args): assert len(prompt_messages) == 6 assert prompt_messages[0].role == PromptMessageRole.SYSTEM assert prompt_messages[0].content == PromptTemplateParser( - template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text + template=messages[0].text ).format({**inputs, "#context#": context}) assert prompt_messages[5].content == query def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args): - model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args + model_config_mock, _, messages, inputs, context = get_chat_model_args files = [] prompt_transform = AdvancedPromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) prompt_messages = prompt_transform._get_chat_model_prompt_messages( - prompt_template_entity=prompt_template_entity, + prompt_template=messages, inputs=inputs, query=None, files=files, context=context, + memory_config=None, memory=None, model_config=model_config_mock ) @@ -123,12 +130,12 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args): assert len(prompt_messages) == 3 assert prompt_messages[0].role == PromptMessageRole.SYSTEM assert prompt_messages[0].content == PromptTemplateParser( - template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text + template=messages[0].text ).format({**inputs, "#context#": context}) def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args): - model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args + model_config_mock, _, 
messages, inputs, context = get_chat_model_args

     files = [
         FileObj(
@@ -148,11 +155,12 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
     prompt_transform = AdvancedPromptTransform()
     prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
     prompt_messages = prompt_transform._get_chat_model_prompt_messages(
-        prompt_template_entity=prompt_template_entity,
+        prompt_template=messages,
         inputs=inputs,
         query=None,
         files=files,
         context=context,
+        memory_config=None,
         memory=None,
         model_config=model_config_mock
     )
@@ -160,7 +168,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
     assert len(prompt_messages) == 4
     assert prompt_messages[0].role == PromptMessageRole.SYSTEM
     assert prompt_messages[0].content == PromptTemplateParser(
-        template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
+        template=messages[0].text
     ).format({**inputs, "#context#": context})
     assert isinstance(prompt_messages[3].content, list)
     assert len(prompt_messages[3].content) == 2
@@ -173,22 +181,31 @@ def get_chat_model_args():
     model_config_mock.provider = 'openai'
     model_config_mock.model = 'gpt-4'

-    prompt_template_entity = PromptTemplateEntity(
-        prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
-        advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
-            messages=[
-                AdvancedChatMessageEntity(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
-                                          role=PromptMessageRole.SYSTEM),
-                AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER),
-                AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT),
-            ]
+    memory_config = MemoryConfig(
+        window=MemoryConfig.WindowConfig(
+            enabled=False
         )
     )

+    prompt_messages = [
+        ChatModelMessage(
+            text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
+            role=PromptMessageRole.SYSTEM
+        ),
+        ChatModelMessage(
+            text="Hi.",
+            role=PromptMessageRole.USER
+        ),
+        ChatModelMessage(
+            text="Hello!",
+            role=PromptMessageRole.ASSISTANT
+        )
+    ]
+
     inputs = {
         "name": "John"
     }

     context = "I am superman."

-    return model_config_mock, prompt_template_entity, inputs, context
+    return model_config_mock, memory_config, prompt_messages, inputs, context
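That closes out patch 173: prompt assembly no longer depends on the app-config PromptTemplateEntity, so the workflow LLM node can drive AdvancedPromptTransform directly with the new ChatModelMessage / CompletionModelPromptTemplate / MemoryConfig entities. A hedged sketch of the new chat-mode call shape; model_config is assumed to be a prepared ModelConfigWithCredentialsEntity, built or mocked the same way the tests above do:

```python
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, MemoryConfig

prompt_messages = AdvancedPromptTransform().get_prompt(
    prompt_template=[
        ChatModelMessage(text='You are {{name}}.', role=PromptMessageRole.SYSTEM),
    ],
    inputs={'name': 'a helpful assistant'},
    query='Hi.',
    files=[],
    context=None,
    memory_config=MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)),
    memory=None,
    model_config=model_config,  # assumption: prepared elsewhere, as in the tests
)
```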
From 2182533af830181f6b88b6c2fa89fa6ed44a91e4 Mon Sep 17 00:00:00 2001
From: Yeuoly
Date: Tue, 12 Mar 2024 22:41:59 +0800
Subject: [PATCH 174/200] feat: javascript code

---
 api/.env.example                              |  2 +-
 .../helper/code_executor/code_executor.py     |  8 ++-
 .../code_executor/javascript_transformer.py   | 54 ++++++++++++++++++-
 api/core/workflow/nodes/code/code_node.py     | 17 ++++--
 api/core/workflow/nodes/code/entities.py      |  2 +-
 5 files changed, 73 insertions(+), 10 deletions(-)

diff --git a/api/.env.example b/api/.env.example
index 4a3b1d65af..c0942412ab 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -135,4 +135,4 @@ BATCH_UPLOAD_LIMIT=10

 # CODE EXECUTION CONFIGURATION
 CODE_EXECUTION_ENDPOINT=
-CODE_EXECUTINO_API_KEY=
+CODE_EXECUTION_API_KEY=
diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py
index 21a8ca5f9f..adfdf6cc69 100644
--- a/api/core/helper/code_executor/code_executor.py
+++ b/api/core/helper/code_executor/code_executor.py
@@ -4,6 +4,7 @@ from typing import Literal, Optional
 from httpx import post
 from pydantic import BaseModel
 from yarl import URL
+from core.helper.code_executor.javascript_transformer import NodeJsTemplateTransformer
 from core.helper.code_executor.jina2_transformer import Jinja2TemplateTransformer
 from core.helper.code_executor.python_transformer import PythonTemplateTransformer

@@ -39,17 +40,20 @@ class CodeExecutor:
             template_transformer = PythonTemplateTransformer
         elif language == 'jinja2':
             template_transformer = Jinja2TemplateTransformer
+        elif language == 'javascript':
+            template_transformer = NodeJsTemplateTransformer
         else:
             raise CodeExecutionException('Unsupported language')

         runner = template_transformer.transform_caller(code, inputs)
-
         url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'run'
         headers = {
             'X-Api-Key': CODE_EXECUTION_API_KEY
         }
         data = {
-            'language': language if language != 'jinja2' else 'python3',
+            'language': 'python3' if language == 'jinja2' else
+                        'nodejs' if language == 'javascript' else
+                        'python3' if language == 'python3' else None,
             'code': runner,
         }
diff --git a/api/core/helper/code_executor/javascript_transformer.py b/api/core/helper/code_executor/javascript_transformer.py
index f87f5c14cb..cc6ad16c66 100644
--- a/api/core/helper/code_executor/javascript_transformer.py
+++ b/api/core/helper/code_executor/javascript_transformer.py
@@ -1 +1,53 @@
-# TODO
\ No newline at end of file
+import json
+import re
+
+from core.helper.code_executor.template_transformer import TemplateTransformer
+
+NODEJS_RUNNER = """// declare main function here
+{{code}}
+
+// execute main function, and return the result
+// inputs is a dict, unstructured inputs
+output = main({{inputs}})
+
+// convert output to json and print
+output = JSON.stringify(output)
+
+result = `<<RESULT>>${output}<<RESULT>>`
+
+console.log(result)
+"""
+
+
+class NodeJsTemplateTransformer(TemplateTransformer):
+    @classmethod
+    def transform_caller(cls, code: str, inputs: dict) -> str:
+        """
+        Transform code to NodeJS runner
+        :param code: code
+        :param inputs: inputs
+        :return:
+        """
+
+        # transform inputs to json string
+        inputs_str = json.dumps(inputs, indent=4)
+
+        # replace code and inputs
+        runner = NODEJS_RUNNER.replace('{{code}}', code)
+        runner = runner.replace('{{inputs}}', inputs_str)
+
+        return runner
+
+    @classmethod
+    def transform_response(cls, response: str) -> dict:
+        """
+        Transform response to dict
+        :param response: response
+        :return:
+        """
+        # extract result
+        result = re.search(r'<<RESULT>>(.*)<<RESULT>>', response, re.DOTALL)
+        if not result:
+            raise ValueError('Failed to parse result')
+        result = result.group(1)
+        return json.loads(result)
diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py
index 2c11e5ba00..5dfe398711 100644
--- a/api/core/workflow/nodes/code/code_node.py
+++ b/api/core/workflow/nodes/code/code_node.py
@@ -15,6 +15,16 @@ MAX_STRING_LENGTH = 1000
 MAX_STRING_ARRAY_LENGTH = 30
 MAX_NUMBER_ARRAY_LENGTH = 1000

+JAVASCRIPT_DEFAULT_CODE = """function main({args1, args2}) {
+    return {
+        result: args1 + args2
+    }
+}"""
+
+PYTHON_DEFAULT_CODE = """def main(args1: int, args2: int) -> dict:
+    return {
+        "result": args1 + args2,
+    }"""

 class CodeNode(BaseNode):
     _node_data_cls = CodeNodeData
@@ -42,9 +52,7 @@ class CodeNode(BaseNode):
                 }
             ],
             "code_language": "javascript",
-            "code": "async function main(arg1, arg2) {\n    return new Promise((resolve, reject) => {"
-                    "\n        if (true) {\n            resolve({\n                \"result\": arg1 + arg2"
-                    "\n            });\n        } else {\n            reject(\"e\");\n        }\n    });\n}",
+            "code": JAVASCRIPT_DEFAULT_CODE,
             "outputs": [
                 {
                     "variable": "result",
@@ -68,8 +76,7 @@ class CodeNode(BaseNode):
                 }
             ],
             "code_language": "python3",
-            "code": "def main(\n    arg1: int,\n    arg2: int,\n) -> int:\n    return {\n        \"result\": arg1 "
-                    "+ arg2\n    }",
+            "code": PYTHON_DEFAULT_CODE,
             "outputs": [
                 {
                     "variable": "result",
diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py
index d4d76c45f9..97e178f5df 100644
--- a/api/core/workflow/nodes/code/entities.py
+++ b/api/core/workflow/nodes/code/entities.py
@@ -17,4 +17,4 @@ class CodeNodeData(BaseNodeData):
     variables: list[VariableSelector]
     code_language: Literal['python3', 'javascript']
     code: str
-    outputs: dict[str, Output]
+    outputs: dict[str, Output]
\ No newline at end of file
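The <<RESULT>> sentinels in NODEJS_RUNNER are the contract between the generated program and the API: the runner prints the JSON-serialized return value between them, and transform_response regexes it back out of the sandbox's stdout. A small round-trip sketch; the stdout string is an assumption standing in for a real sandbox response:

```python
from core.helper.code_executor.javascript_transformer import NodeJsTemplateTransformer

# Wrap user code into the NODEJS_RUNNER template shown above.
code = 'function main({a, b}) { return { sum: a + b } }'
runner = NodeJsTemplateTransformer.transform_caller(code, {'a': 1, 'b': 2})

# `stdout` stands in for what the sandbox would print when executing `runner`.
stdout = '<<RESULT>>{"sum": 3}<<RESULT>>'
assert NodeJsTemplateTransformer.transform_response(stdout) == {'sum': 3}
```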
From e6572ef2d76b3cce21ca6a41be1c4c824a63a1d9 Mon Sep 17 00:00:00 2001
From: Yeuoly
Date: Tue, 12 Mar 2024 22:42:28 +0800
Subject: [PATCH 175/200] fix: linter

---
 api/core/helper/code_executor/code_executor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py
index adfdf6cc69..9d74edee0e 100644
--- a/api/core/helper/code_executor/code_executor.py
+++ b/api/core/helper/code_executor/code_executor.py
@@ -4,8 +4,8 @@ from typing import Literal, Optional
 from httpx import post
 from pydantic import BaseModel
 from yarl import URL
-from core.helper.code_executor.javascript_transformer import NodeJsTemplateTransformer
+
+from core.helper.code_executor.javascript_transformer import NodeJsTemplateTransformer
 from core.helper.code_executor.jina2_transformer import Jinja2TemplateTransformer
 from core.helper.code_executor.python_transformer import PythonTemplateTransformer

From e4794e309a94b992e4504ea93f76196cd04127ad Mon Sep 17 00:00:00 2001
From: takatost
Date: Tue, 12 Mar 2024 23:08:14 +0800
Subject: [PATCH 176/200] add llm node test

---
 .../workflow/nodes/__init__.py                  |   0
 .../workflow/nodes/test_llm.py                  | 132 ++++++++++++++++++
 .../workflow/nodes/test_template_transform.py   |   4 +-
 .../core/workflow/nodes/__init__.py             |   0
 4 files changed, 134 insertions(+), 2 deletions(-)
 create mode 100644 api/tests/integration_tests/workflow/nodes/__init__.py
 create mode 100644 api/tests/integration_tests/workflow/nodes/test_llm.py
 create mode 100644 api/tests/unit_tests/core/workflow/nodes/__init__.py

diff --git 
a/api/tests/integration_tests/workflow/nodes/__init__.py b/api/tests/integration_tests/workflow/nodes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py new file mode 100644 index 0000000000..18fba566bf --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -0,0 +1,132 @@ +import os +from unittest.mock import MagicMock + +import pytest + +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.entities.provider_configuration import ProviderModelBundle, ProviderConfiguration +from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, CustomProviderConfiguration +from core.model_manager import ModelInstance +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers import ModelProviderFactory +from core.workflow.entities.node_entities import SystemVariable +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import UserFrom +from core.workflow.nodes.llm.llm_node import LLMNode +from extensions.ext_database import db +from models.provider import ProviderType +from models.workflow import WorkflowNodeExecutionStatus + +"""FOR MOCK FIXTURES, DO NOT REMOVE""" +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +def test_execute_llm(setup_openai_mock): + node = LLMNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'llm', + 'data': { + 'title': '123', + 'type': 'llm', + 'model': { + 'provider': 'openai', + 'name': 'gpt-3.5.turbo', + 'mode': 'chat', + 'completion_params': {} + }, + 'variables': [ + { + 'variable': 'weather', + 'value_selector': ['abc', 'output'], + }, + { + 'variable': 'query', + 'value_selector': ['sys', 'query'] + } + ], + 'prompt_template': [ + { + 'role': 'system', + 'text': 'you are a helpful assistant.\ntoday\'s weather is {{weather}}.' 
+ }, + { + 'role': 'user', + 'text': '{{query}}' + } + ], + 'memory': { + 'window': { + 'enabled': True, + 'size': 2 + } + }, + 'context': { + 'enabled': False + }, + 'vision': { + 'enabled': False + } + } + } + ) + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.QUERY: 'what\'s the weather today?', + SystemVariable.FILES: [], + SystemVariable.CONVERSATION: 'abababa' + }, user_inputs={}) + pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny') + + credentials = { + 'openai_api_key': os.environ.get('OPENAI_API_KEY') + } + + provider_instance = ModelProviderFactory().get_provider_instance('openai') + model_type_instance = provider_instance.get_model_instance(ModelType.LLM) + provider_model_bundle = ProviderModelBundle( + configuration=ProviderConfiguration( + tenant_id='1', + provider=provider_instance.get_provider_schema(), + preferred_provider_type=ProviderType.CUSTOM, + using_provider_type=ProviderType.CUSTOM, + system_configuration=SystemConfiguration( + enabled=False + ), + custom_configuration=CustomConfiguration( + provider=CustomProviderConfiguration( + credentials=credentials + ) + ) + ), + provider_instance=provider_instance, + model_type_instance=model_type_instance + ) + model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo') + model_config = ModelConfigWithCredentialsEntity( + model='gpt-3.5-turbo', + provider='openai', + mode='chat', + credentials=credentials, + parameters={}, + model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'), + provider_model_bundle=provider_model_bundle + ) + + # Mock db.session.close() + db.session.close = MagicMock() + + node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config])) + + # execute node + result = node.run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs['text'] is not None + assert result.outputs['usage']['total_tokens'] > 0 diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 4348995a05..36cf0a070a 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -1,7 +1,7 @@ import pytest -from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from models.workflow import WorkflowNodeExecutionStatus from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock @@ -14,7 +14,7 @@ def test_execute_code(setup_code_executor_mock): app_id='1', workflow_id='1', user_id='1', - user_from=InvokeFrom.WEB_APP, + user_from=UserFrom.END_USER, config={ 'id': '1', 'data': { diff --git a/api/tests/unit_tests/core/workflow/nodes/__init__.py b/api/tests/unit_tests/core/workflow/nodes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 From da3e1e9d14a2b6aa102709898d0469a5962bdb9d Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 13 Mar 2024 00:08:13 +0800 Subject: [PATCH 177/200] add deduct quota for llm node --- api/core/workflow/nodes/llm/llm_node.py | 56 ++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/llm/llm_node.py 
b/api/core/workflow/nodes/llm/llm_node.py index d1050a5f5b..9285bbe74e 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -3,6 +3,7 @@ from typing import Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus +from core.entities.provider_entities import QuotaUnit from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory @@ -21,6 +22,7 @@ from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.llm.entities import LLMNodeData from extensions.ext_database import db from models.model import Conversation +from models.provider import Provider, ProviderType from models.workflow import WorkflowNodeExecutionStatus @@ -144,10 +146,15 @@ class LLMNode(BaseNode): ) # handle invoke result - return self._handle_invoke_result( + text, usage = self._handle_invoke_result( invoke_result=invoke_result ) + # deduct quota + self._deduct_llm_quota(model_instance=model_instance, usage=usage) + + return text, usage + def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]: """ Handle invoke result @@ -373,6 +380,53 @@ class LLMNode(BaseNode): return prompt_messages, stop + def _deduct_llm_quota(self, model_instance: ModelInstance, usage: LLMUsage) -> None: + """ + Deduct LLM quota + :param model_instance: model instance + :param usage: usage + :return: + """ + provider_model_bundle = model_instance.provider_model_bundle + provider_configuration = provider_model_bundle.configuration + + if provider_configuration.using_provider_type != ProviderType.SYSTEM: + return + + system_configuration = provider_configuration.system_configuration + + quota_unit = None + for quota_configuration in system_configuration.quota_configurations: + if quota_configuration.quota_type == system_configuration.current_quota_type: + quota_unit = quota_configuration.quota_unit + + if quota_configuration.quota_limit == -1: + return + + break + + used_quota = None + if quota_unit: + if quota_unit == QuotaUnit.TOKENS: + used_quota = usage.total_tokens + elif quota_unit == QuotaUnit.CREDITS: + used_quota = 1 + + if 'gpt-4' in model_instance.model: + used_quota = 20 + else: + used_quota = 1 + + if used_quota is not None: + db.session.query(Provider).filter( + Provider.tenant_id == self.tenant_id, + Provider.provider_name == model_instance.provider, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == system_configuration.current_quota_type.value, + Provider.quota_limit > Provider.quota_used + ).update({'quota_used': Provider.quota_used + used_quota}) + db.session.commit() + @classmethod def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: """ From 2b4b6817a3d082c8a4421918b2aef672771bad2f Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 13 Mar 2024 14:55:56 +0800 Subject: [PATCH 178/200] record inputs and process data when node failed --- .../workflow_event_trigger_callback.py | 6 +++++- .../workflow_event_trigger_callback.py | 6 +++++- api/core/app/entities/queue_entities.py | 3 +++ .../callbacks/base_workflow_callback.py | 4 +++- api/core/workflow/workflow_engine_manager.py | 4 +++- api/models/workflow.py | 18 +++++++++--------- .../workflow/nodes/test_llm.py | 2 +- 7 files changed, 29 insertions(+), 14 deletions(-) diff --git 
a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py index d9c8a2c96d..b4a6a9602f 100644 --- a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py +++ b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py @@ -96,7 +96,9 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): def on_workflow_node_execute_failed(self, node_id: str, node_type: NodeType, node_data: BaseNodeData, - error: str) -> None: + error: str, + inputs: Optional[dict] = None, + process_data: Optional[dict] = None) -> None: """ Workflow node execute failed """ @@ -105,6 +107,8 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): node_id=node_id, node_type=node_type, node_data=node_data, + inputs=inputs, + process_data=process_data, error=error ), PublishFrom.APPLICATION_MANAGER diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py index 318466711a..ea7eb5688c 100644 --- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py +++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py @@ -96,7 +96,9 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): def on_workflow_node_execute_failed(self, node_id: str, node_type: NodeType, node_data: BaseNodeData, - error: str) -> None: + error: str, + inputs: Optional[dict] = None, + process_data: Optional[dict] = None) -> None: """ Workflow node execute failed """ @@ -105,6 +107,8 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): node_id=node_id, node_type=node_type, node_data=node_data, + inputs=inputs, + process_data=process_data, error=error ), PublishFrom.APPLICATION_MANAGER diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 0ea7744b58..153607e1b4 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -158,6 +158,9 @@ class QueueNodeFailedEvent(AppQueueEvent): node_type: NodeType node_data: BaseNodeData + inputs: Optional[dict] = None + process_data: Optional[dict] = None + error: str diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py index cf2915ed86..9594fa2037 100644 --- a/api/core/workflow/callbacks/base_workflow_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -55,7 +55,9 @@ class BaseWorkflowCallback(ABC): def on_workflow_node_execute_failed(self, node_id: str, node_type: NodeType, node_data: BaseNodeData, - error: str) -> None: + error: str, + inputs: Optional[dict] = None, + process_data: Optional[dict] = None) -> None: """ Workflow node execute failed """ diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 49b9d4ac4d..ebc753537e 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -420,7 +420,9 @@ class WorkflowEngineManager: node_id=node.node_id, node_type=node.node_type, node_data=node.node_data, - error=node_run_result.error + error=node_run_result.error, + inputs=node_run_result.inputs, + process_data=node_run_result.process_data, ) raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}") diff --git a/api/models/workflow.py b/api/models/workflow.py index 5a3cdcf83c..9c5b2a0b8f 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -123,11 +123,11 @@ class Workflow(db.Model): 
@property def graph_dict(self): - return self.graph if not self.graph else json.loads(self.graph) + return json.loads(self.graph) if self.graph else None @property def features_dict(self): - return self.features if not self.features else json.loads(self.features) + return json.loads(self.features) if self.features else None def user_input_form(self) -> list: # get start node from graph @@ -270,15 +270,15 @@ class WorkflowRun(db.Model): @property def graph_dict(self): - return self.graph if not self.graph else json.loads(self.graph) + return json.loads(self.graph) if self.graph else None @property def inputs_dict(self): - return self.inputs if not self.inputs else json.loads(self.inputs) + return json.loads(self.inputs) if self.inputs else None @property def outputs_dict(self): - return self.outputs if not self.outputs else json.loads(self.outputs) + return json.loads(self.outputs) if self.outputs else None class WorkflowNodeExecutionTriggeredFrom(Enum): @@ -419,19 +419,19 @@ class WorkflowNodeExecution(db.Model): @property def inputs_dict(self): - return self.inputs if not self.inputs else json.loads(self.inputs) + return json.loads(self.inputs) if self.inputs else None @property def outputs_dict(self): - return self.outputs if not self.outputs else json.loads(self.outputs) + return json.loads(self.outputs) if self.outputs else None @property def process_data_dict(self): - return self.process_data if not self.process_data else json.loads(self.process_data) + return json.loads(self.process_data) if self.process_data else None @property def execution_metadata_dict(self): - return self.execution_metadata if not self.execution_metadata else json.loads(self.execution_metadata) + return json.loads(self.execution_metadata) if self.execution_metadata else None class WorkflowAppLogCreatedFrom(Enum): diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 18fba566bf..999ebf7734 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -36,7 +36,7 @@ def test_execute_llm(setup_openai_mock): 'type': 'llm', 'model': { 'provider': 'openai', - 'name': 'gpt-3.5.turbo', + 'name': 'gpt-3.5-turbo', 'mode': 'chat', 'completion_params': {} }, From 5213b0aade7efd50e1df43b055822db49bbbc71c Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 13 Mar 2024 15:01:02 +0800 Subject: [PATCH 179/200] add sequence_number for workflow_started event --- api/core/app/apps/advanced_chat/generate_task_pipeline.py | 1 + api/core/app/apps/workflow/generate_task_pipeline.py | 1 + 2 files changed, 2 insertions(+) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index d5d3feded0..e8463e59d3 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -226,6 +226,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): 'data': { 'id': workflow_run.id, 'workflow_id': workflow_run.workflow_id, + 'sequence_number': workflow_run.sequence_number, 'created_at': int(workflow_run.created_at.timestamp()) } } diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 7a244151f2..cd1ea4c81e 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -195,6 +195,7 @@ class 
WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                 'data': {
                     'id': workflow_run.id,
                     'workflow_id': workflow_run.workflow_id,
+                    'sequence_number': workflow_run.sequence_number,
                     'created_at': int(workflow_run.created_at.timestamp())
                 }
             }

From 7e53625eae2fd41ae739e3c7e121555f7a846526 Mon Sep 17 00:00:00 2001
From: takatost
Date: Wed, 13 Mar 2024 15:08:15 +0800
Subject: [PATCH 180/200] fix value type

---
 api/core/workflow/entities/variable_pool.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py
index 3868041a8f..7a5f58d808 100644
--- a/api/core/workflow/entities/variable_pool.py
+++ b/api/core/workflow/entities/variable_pool.py
@@ -13,7 +13,10 @@ class ValueType(Enum):
     STRING = "string"
     NUMBER = "number"
     OBJECT = "object"
-    ARRAY = "array"
+    ARRAY_STRING = "array[string]"
+    ARRAY_NUMBER = "array[number]"
+    ARRAY_OBJECT = "array[object]"
+    ARRAY_FILE = "array[file]"
     FILE = "file"
 
 
@@ -78,7 +81,10 @@ class VariablePool:
         elif target_value_type == ValueType.OBJECT:
             if not isinstance(value, dict):
                 raise ValueError('Invalid value type: object')
-        elif target_value_type == ValueType.ARRAY:
+        elif target_value_type in [ValueType.ARRAY_STRING,
+                                   ValueType.ARRAY_NUMBER,
+                                   ValueType.ARRAY_OBJECT,
+                                   ValueType.ARRAY_FILE]:
             if not isinstance(value, list):
                 raise ValueError('Invalid value type: array')
 
From 735b55e61b0751cf5ab75974b0f146474c9c575a Mon Sep 17 00:00:00 2001
From: takatost
Date: Wed, 13 Mar 2024 17:10:51 +0800
Subject: [PATCH 181/200] add if-else node

---
 api/core/workflow/entities/variable_pool.py |   2 +-
 api/core/workflow/nodes/if_else/entities.py |  26 ++
 .../workflow/nodes/if_else/if_else_node.py  | 395 +++++++++++++++++-
 .../core/workflow/nodes/if_else_node.py     | 193 +++++++++
 4 files changed, 614 insertions(+), 2 deletions(-)
 create mode 100644 api/core/workflow/nodes/if_else/entities.py
 create mode 100644 api/tests/unit_tests/core/workflow/nodes/if_else_node.py

diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py
index 7a5f58d808..ff96bc3bac 100644
--- a/api/core/workflow/entities/variable_pool.py
+++ b/api/core/workflow/entities/variable_pool.py
@@ -86,6 +86,6 @@ class VariablePool:
                                    ValueType.ARRAY_OBJECT,
                                    ValueType.ARRAY_FILE]:
             if not isinstance(value, list):
-                raise ValueError('Invalid value type: array')
+                raise ValueError(f'Invalid value type: {target_value_type.value}')
 
         return value
diff --git a/api/core/workflow/nodes/if_else/entities.py b/api/core/workflow/nodes/if_else/entities.py
new file mode 100644
index 0000000000..68d51c93be
--- /dev/null
+++ b/api/core/workflow/nodes/if_else/entities.py
@@ -0,0 +1,26 @@
+from typing import Literal, Optional
+
+from pydantic import BaseModel
+
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+
+
+class IfElseNodeData(BaseNodeData):
+    """
+    If Else Node Data.
+ """ + class Condition(BaseModel): + """ + Condition entity + """ + variable_selector: list[str] + comparison_operator: Literal[ + # for string or array + "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", + # for number + "=", "≠", ">", "<", "≥", "≤", "null", "not null" + ] + value: Optional[str] = None + + logical_operator: Literal["and", "or"] = "and" + conditions: list[Condition] diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 98a5c85db2..9cb084b116 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -1,5 +1,398 @@ +from typing import Optional, cast + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.if_else.entities import IfElseNodeData +from models.workflow import WorkflowNodeExecutionStatus class IfElseNode(BaseNode): - pass + _node_data_cls = IfElseNodeData + node_type = NodeType.IF_ELSE + + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + """ + Run node + :param variable_pool: variable pool + :return: + """ + node_data = self.node_data + node_data = cast(self._node_data_cls, node_data) + + node_inputs = { + "conditions": [] + } + + process_datas = { + "condition_results": [] + } + + try: + logical_operator = node_data.logical_operator + input_conditions = [] + for condition in node_data.conditions: + actual_value = variable_pool.get_variable_value( + variable_selector=condition.variable_selector + ) + + expected_value = condition.value + + input_conditions.append({ + "actual_value": actual_value, + "expected_value": expected_value, + "comparison_operator": condition.comparison_operator + }) + + node_inputs["conditions"] = input_conditions + + for input_condition in input_conditions: + actual_value = input_condition["actual_value"] + expected_value = input_condition["expected_value"] + comparison_operator = input_condition["comparison_operator"] + + if comparison_operator == "contains": + compare_result = self._assert_contains(actual_value, expected_value) + elif comparison_operator == "not contains": + compare_result = self._assert_not_contains(actual_value, expected_value) + elif comparison_operator == "start with": + compare_result = self._assert_start_with(actual_value, expected_value) + elif comparison_operator == "end with": + compare_result = self._assert_end_with(actual_value, expected_value) + elif comparison_operator == "is": + compare_result = self._assert_is(actual_value, expected_value) + elif comparison_operator == "is not": + compare_result = self._assert_is_not(actual_value, expected_value) + elif comparison_operator == "empty": + compare_result = self._assert_empty(actual_value) + elif comparison_operator == "not empty": + compare_result = self._assert_not_empty(actual_value) + elif comparison_operator == "=": + compare_result = self._assert_equal(actual_value, expected_value) + elif comparison_operator == "≠": + compare_result = self._assert_not_equal(actual_value, expected_value) + elif comparison_operator == ">": + compare_result = self._assert_greater_than(actual_value, expected_value) + elif comparison_operator == "<": + compare_result = self._assert_less_than(actual_value, expected_value) + elif comparison_operator == "≥": + 
compare_result = self._assert_greater_than_or_equal(actual_value, expected_value) + elif comparison_operator == "≤": + compare_result = self._assert_less_than_or_equal(actual_value, expected_value) + elif comparison_operator == "null": + compare_result = self._assert_null(actual_value) + elif comparison_operator == "not null": + compare_result = self._assert_not_null(actual_value) + else: + continue + + process_datas["condition_results"].append({ + **input_condition, + "result": compare_result + }) + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=node_inputs, + process_datas=process_datas, + error=str(e) + ) + + if logical_operator == "and": + compare_result = False not in [condition["result"] for condition in process_datas["condition_results"]] + else: + compare_result = True in [condition["result"] for condition in process_datas["condition_results"]] + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=node_inputs, + process_datas=process_datas, + edge_source_handle="false" if not compare_result else "true", + outputs={ + "result": compare_result + } + ) + + def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: + """ + Assert contains + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return False + + if not isinstance(actual_value, str | list): + raise ValueError('Invalid actual value type: string or array') + + if expected_value not in actual_value: + return False + return True + + def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: + """ + Assert not contains + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return True + + if not isinstance(actual_value, str | list): + raise ValueError('Invalid actual value type: string or array') + + if expected_value in actual_value: + return False + return True + + def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert start with + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return False + + if not isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if not actual_value.startswith(expected_value): + return False + return True + + def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert end with + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return False + + if not isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if not actual_value.endswith(expected_value): + return False + return True + + def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert is + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if actual_value != expected_value: + return False + return True + + def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert is not + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not 
isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if actual_value == expected_value: + return False + return True + + def _assert_empty(self, actual_value: Optional[str]) -> bool: + """ + Assert empty + :param actual_value: actual value + :return: + """ + if not actual_value: + return True + return False + + def _assert_not_empty(self, actual_value: Optional[str]) -> bool: + """ + Assert not empty + :param actual_value: actual value + :return: + """ + if actual_value: + return True + return False + + def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value != expected_value: + return False + return True + + def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert not equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value == expected_value: + return False + return True + + def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert greater than + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value <= expected_value: + return False + return True + + def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert less than + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value >= expected_value: + return False + return True + + def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert greater than or equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value < expected_value: + return False + return True + + def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert less than or equal + :param actual_value: actual value + :param 
expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value > expected_value: + return False + return True + + def _assert_null(self, actual_value: Optional[int | float]) -> bool: + """ + Assert null + :param actual_value: actual value + :return: + """ + if actual_value is None: + return True + return False + + def _assert_not_null(self, actual_value: Optional[int | float]) -> bool: + """ + Assert not null + :param actual_value: actual value + :return: + """ + if actual_value is not None: + return True + return False + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + return {} diff --git a/api/tests/unit_tests/core/workflow/nodes/if_else_node.py b/api/tests/unit_tests/core/workflow/nodes/if_else_node.py new file mode 100644 index 0000000000..7b402ad0a0 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/if_else_node.py @@ -0,0 +1,193 @@ +from unittest.mock import MagicMock + +from core.workflow.entities.node_entities import SystemVariable +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import UserFrom +from core.workflow.nodes.if_else.if_else_node import IfElseNode +from extensions.ext_database import db +from models.workflow import WorkflowNodeExecutionStatus + + +def test_execute_if_else_result_true(): + node = IfElseNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'if-else', + 'data': { + 'title': '123', + 'type': 'if-else', + 'logical_operator': 'and', + 'conditions': [ + { + 'comparison_operator': 'contains', + 'variable_selector': ['start', 'array_contains'], + 'value': 'ab' + }, + { + 'comparison_operator': 'not contains', + 'variable_selector': ['start', 'array_not_contains'], + 'value': 'ab' + }, + { + 'comparison_operator': 'contains', + 'variable_selector': ['start', 'contains'], + 'value': 'ab' + }, + { + 'comparison_operator': 'not contains', + 'variable_selector': ['start', 'not_contains'], + 'value': 'ab' + }, + { + 'comparison_operator': 'start with', + 'variable_selector': ['start', 'start_with'], + 'value': 'ab' + }, + { + 'comparison_operator': 'end with', + 'variable_selector': ['start', 'end_with'], + 'value': 'ab' + }, + { + 'comparison_operator': 'is', + 'variable_selector': ['start', 'is'], + 'value': 'ab' + }, + { + 'comparison_operator': 'is not', + 'variable_selector': ['start', 'is_not'], + 'value': 'ab' + }, + { + 'comparison_operator': 'empty', + 'variable_selector': ['start', 'empty'], + 'value': 'ab' + }, + { + 'comparison_operator': 'not empty', + 'variable_selector': ['start', 'not_empty'], + 'value': 'ab' + }, + { + 'comparison_operator': '=', + 'variable_selector': ['start', 'equals'], + 'value': '22' + }, + { + 'comparison_operator': '≠', + 'variable_selector': ['start', 'not_equals'], + 'value': '22' + }, + { + 'comparison_operator': '>', + 'variable_selector': ['start', 'greater_than'], + 'value': '22' + }, + { + 'comparison_operator': '<', + 'variable_selector': ['start', 'less_than'], + 'value': '22' + }, + { + 'comparison_operator': '≥', + 'variable_selector': ['start', 
'greater_than_or_equal'], + 'value': '22' + }, + { + 'comparison_operator': '≤', + 'variable_selector': ['start', 'less_than_or_equal'], + 'value': '22' + }, + { + 'comparison_operator': 'null', + 'variable_selector': ['start', 'null'] + }, + { + 'comparison_operator': 'not null', + 'variable_selector': ['start', 'not_null'] + }, + ] + } + } + ) + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.FILES: [], + }, user_inputs={}) + pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['ab', 'def']) + pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ac', 'def']) + pool.append_variable(node_id='start', variable_key_list=['contains'], value='cabcde') + pool.append_variable(node_id='start', variable_key_list=['not_contains'], value='zacde') + pool.append_variable(node_id='start', variable_key_list=['start_with'], value='abc') + pool.append_variable(node_id='start', variable_key_list=['end_with'], value='zzab') + pool.append_variable(node_id='start', variable_key_list=['is'], value='ab') + pool.append_variable(node_id='start', variable_key_list=['is_not'], value='aab') + pool.append_variable(node_id='start', variable_key_list=['empty'], value='') + pool.append_variable(node_id='start', variable_key_list=['not_empty'], value='aaa') + pool.append_variable(node_id='start', variable_key_list=['equals'], value=22) + pool.append_variable(node_id='start', variable_key_list=['not_equals'], value=23) + pool.append_variable(node_id='start', variable_key_list=['greater_than'], value=23) + pool.append_variable(node_id='start', variable_key_list=['less_than'], value=21) + pool.append_variable(node_id='start', variable_key_list=['greater_than_or_equal'], value=22) + pool.append_variable(node_id='start', variable_key_list=['less_than_or_equal'], value=21) + pool.append_variable(node_id='start', variable_key_list=['not_null'], value='1212') + + # Mock db.session.close() + db.session.close = MagicMock() + + # execute node + result = node._run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs['result'] is True + + +def test_execute_if_else_result_false(): + node = IfElseNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'if-else', + 'data': { + 'title': '123', + 'type': 'if-else', + 'logical_operator': 'or', + 'conditions': [ + { + 'comparison_operator': 'contains', + 'variable_selector': ['start', 'array_contains'], + 'value': 'ab' + }, + { + 'comparison_operator': 'not contains', + 'variable_selector': ['start', 'array_not_contains'], + 'value': 'ab' + } + ] + } + } + ) + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.FILES: [], + }, user_inputs={}) + pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['1ab', 'def']) + pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ab', 'def']) + + # Mock db.session.close() + db.session.close = MagicMock() + + # execute node + result = node._run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs['result'] is False From 6b19ba3bb2821733f4ac1be91266bfde7c0d9eeb Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Wed, 13 Mar 2024 17:46:42 +0800 Subject: [PATCH 182/200] enhance: sandbox-docker-compose --- api/.env.example | 4 ++-- docker/docker-compose.middleware.yaml | 3 +++ docker/docker-compose.yaml | 3 +++ 3 files 
changed, 8 insertions(+), 2 deletions(-) diff --git a/api/.env.example b/api/.env.example index c0942412ab..832c7e3bab 100644 --- a/api/.env.example +++ b/api/.env.example @@ -134,5 +134,5 @@ SSRF_PROXY_HTTPS_URL= BATCH_UPLOAD_LIMIT=10 # CODE EXECUTION CONFIGURATION -CODE_EXECUTION_ENDPOINT= -CODE_EXECUTION_API_KEY= +CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194 +CODE_EXECUTION_API_KEY=dify-sandbox diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 8fba59c315..4f7965609b 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -55,9 +55,12 @@ services: sandbox: image: langgenius/dify-sandbox:latest restart: always + cap_add: + - SYS_ADMIN environment: # The DifySandbox configurations API_KEY: dify-sandbox + GIN_MODE: 'release' ports: - "8194:8194" diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index ca6b6cbf1a..f066582ac8 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -293,9 +293,12 @@ services: sandbox: image: langgenius/dify-sandbox:latest restart: always + cap_add: + - SYS_ADMIN environment: # The DifySandbox configurations API_KEY: dify-sandbox + GIN_MODE: release ports: - "8194:8194" From e5ff06bcb78a39691410fcff4e34528040c5b1b3 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 13 Mar 2024 18:02:07 +0800 Subject: [PATCH 183/200] fix err typo --- api/core/workflow/nodes/if_else/if_else_node.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 9cb084b116..44a4091a2e 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -95,7 +95,7 @@ class IfElseNode(BaseNode): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=node_inputs, - process_datas=process_datas, + process_data=process_datas, error=str(e) ) @@ -107,7 +107,7 @@ class IfElseNode(BaseNode): return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, - process_datas=process_datas, + process_data=process_datas, edge_source_handle="false" if not compare_result else "true", outputs={ "result": compare_result From 0614ddde7dedc0465eb827e40dc170d965f6651a Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Wed, 13 Mar 2024 20:40:37 +0800 Subject: [PATCH 184/200] fix: allow None AuthorizationConfig --- .../workflow/nodes/http_request/entities.py | 17 +++++++++-- .../workflow/nodes/test_http.py | 30 +++++++++++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index ce806b6bdb..fbd4da3840 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -1,6 +1,6 @@ from typing import Literal, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, validator from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector @@ -17,7 +17,20 @@ class HttpRequestNodeData(BaseNodeData): header: Union[None, str] type: Literal['no-auth', 'api-key'] - config: Config + config: Optional[Config] + + @validator('config', always=True, pre=True) + def check_config(cls, v, values): + """ + Check config, if type is no-auth, config should be None, otherwise it should be a dict. 
+ """ + if values['type'] == 'no-auth': + return None + else: + if not v or not isinstance(v, dict): + raise ValueError('config should be a dict') + + return v class Body(BaseModel): type: Literal[None, 'form-data', 'x-www-form-urlencoded', 'raw', 'json'] diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 6df8f6b673..584e1d80a5 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -54,6 +54,36 @@ def test_get(setup_http_mock): assert 'api-key: Basic ak-xxx' in data assert 'X-Header: 123' in data +@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +def test_no_auth(setup_http_mock): + node = HttpRequestNode(config={ + 'id': '1', + 'data': { + 'title': 'http', + 'desc': '', + 'variables': [{ + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }], + 'method': 'get', + 'url': 'http://example.com', + 'authorization': { + 'type': 'no-auth', + 'config': None, + }, + 'headers': 'X-Header:123', + 'params': 'A:b', + 'body': None, + } + }, **BASIC_NODE_DATA) + + result = node.run(pool) + + data = result.process_data.get('request', '') + + assert '?A=b' in data + assert 'X-Header: 123' in data + @pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) def test_template(setup_http_mock): node = HttpRequestNode(config={ From 5a67c09b48d18a398a23b86be5f33c39fcabec0a Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 13 Mar 2024 20:54:23 +0800 Subject: [PATCH 185/200] use answer node instead of end in advanced chatbot --- api/services/workflow/workflow_converter.py | 67 ++++++++++++--------- 1 file changed, 39 insertions(+), 28 deletions(-) diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 4c7e4db47a..78f79e02fa 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -19,7 +19,6 @@ from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.simple_prompt_transform import SimplePromptTransform from core.workflow.entities.node_entities import NodeType -from core.workflow.nodes.end.entities import EndNodeOutputType from events.app_event import app_was_created from extensions.ext_database import db from models.account import Account @@ -149,10 +148,13 @@ class WorkflowConverter: graph = self._append_node(graph, llm_node) - # convert to end node by app mode - end_node = self._convert_to_end_node(app_model=app_model) - - graph = self._append_node(graph, end_node) + if new_app_mode == AppMode.WORKFLOW: + # convert to end node by app mode + end_node = self._convert_to_end_node() + graph = self._append_node(graph, end_node) + else: + answer_node = self._convert_to_answer_node() + graph = self._append_node(graph, answer_node) app_model_config_dict = app_config.app_model_config_dict @@ -517,35 +519,44 @@ class WorkflowConverter: } } - def _convert_to_end_node(self, app_model: App) -> dict: + def _convert_to_end_node(self) -> dict: """ Convert to End Node - :param app_model: App instance :return: """ - if app_model.mode == AppMode.CHAT.value: - return { - "id": "end", - "position": None, - "data": { - "title": "END", - "type": NodeType.END.value, + # for original completion app + return { + "id": "end", + "position": None, + "data": { + "title": "END", + "type": NodeType.END.value, + "outputs": { + "variable": 
"result", + "value_selector": ["llm", "text"] } } - elif app_model.mode == AppMode.COMPLETION.value: - # for original completion app - return { - "id": "end", - "position": None, - "data": { - "title": "END", - "type": NodeType.END.value, - "outputs": { - "type": EndNodeOutputType.PLAIN_TEXT.value, - "plain_text_selector": ["llm", "text"] - } - } + } + + def _convert_to_answer_node(self) -> dict: + """ + Convert to Answer Node + :return: + """ + # for original chat app + return { + "id": "answer", + "position": None, + "data": { + "title": "ANSWER", + "type": NodeType.ANSWER.value, + "variables": { + "variable": "text", + "value_selector": ["llm", "text"] + }, + "answer": "{{text}}" } + } def _create_edge(self, source: str, target: str) -> dict: """ @@ -582,7 +593,7 @@ class WorkflowConverter: if app_model.mode == AppMode.COMPLETION.value: return AppMode.WORKFLOW else: - return AppMode.value_of(app_model.mode) + return AppMode.ADVANCED_CHAT def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: """ From 44c4d5be72d2fcfd2930de377015968f2f75ae22 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 13 Mar 2024 23:00:28 +0800 Subject: [PATCH 186/200] add answer output parse --- .../workflow_event_trigger_callback.py | 31 +--------- api/core/workflow/nodes/answer/answer_node.py | 50 +++++++++++++-- api/core/workflow/nodes/base_node.py | 14 +---- api/core/workflow/nodes/end/end_node.py | 38 +++--------- api/core/workflow/nodes/end/entities.py | 61 +------------------ api/core/workflow/workflow_engine_manager.py | 4 ++ api/services/workflow/workflow_converter.py | 4 +- .../core/workflow/nodes/test_answer.py | 56 +++++++++++++++++ .../{if_else_node.py => test_if_else.py} | 0 9 files changed, 120 insertions(+), 138 deletions(-) create mode 100644 api/tests/unit_tests/core/workflow/nodes/test_answer.py rename api/tests/unit_tests/core/workflow/nodes/{if_else_node.py => test_if_else.py} (100%) diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py index ea7eb5688c..59ef44cd2e 100644 --- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py +++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py @@ -5,7 +5,6 @@ from core.app.entities.queue_entities import ( QueueNodeFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, - QueueTextChunkEvent, QueueWorkflowFailedEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, @@ -20,7 +19,6 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): self._queue_manager = queue_manager - self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict) def on_workflow_run_started(self) -> None: """ @@ -118,31 +116,4 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): """ Publish text chunk """ - if node_id in self._streamable_node_ids: - self._queue_manager.publish( - QueueTextChunkEvent( - text=text - ), PublishFrom.APPLICATION_MANAGER - ) - - def _fetch_streamable_node_ids(self, graph: dict) -> list[str]: - """ - Fetch streamable node ids - When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output - When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output - - :param graph: workflow graph - :return: - """ - streamable_node_ids = [] - end_node_ids = [] - for node_config in 
graph.get('nodes'): - if node_config.get('data', {}).get('type') == NodeType.END.value: - if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text': - end_node_ids.append(node_config.get('id')) - - for edge_config in graph.get('edges'): - if edge_config.get('target') in end_node_ids: - streamable_node_ids.append(edge_config.get('source')) - - return streamable_node_ids + pass diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 381ada1a1e..97ddafad01 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -1,4 +1,3 @@ -import time from typing import cast from core.prompt.utils.prompt_template_parser import PromptTemplateParser @@ -32,14 +31,49 @@ class AnswerNode(BaseNode): variable_values[variable_selector.variable] = value + variable_keys = list(variable_values.keys()) + # format answer template template_parser = PromptTemplateParser(node_data.answer) - answer = template_parser.format(variable_values) + template_variable_keys = template_parser.variable_keys - # publish answer as stream - for word in answer: - self.publish_text_chunk(word) - time.sleep(10) # TODO for debug + # Take the intersection of variable_keys and template_variable_keys + variable_keys = list(set(variable_keys) & set(template_variable_keys)) + + template = node_data.answer + for var in variable_keys: + template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω') + + split_template = [ + { + "type": "var" if self._is_variable(part, variable_keys) else "text", + "value": part.replace('Ω', '') if self._is_variable(part, variable_keys) else part + } + for part in template.split('Ω') if part + ] + + answer = [] + for part in split_template: + if part["type"] == "var": + value = variable_values.get(part["value"].replace('{{', '').replace('}}', '')) + answer_part = { + "type": "text", + "text": value + } + # TODO File + else: + answer_part = { + "type": "text", + "text": part["value"] + } + + if len(answer) > 0 and answer[-1]["type"] == "text" and answer_part["type"] == "text": + answer[-1]["text"] += answer_part["text"] + else: + answer.append(answer_part) + + if len(answer) == 1 and answer[0]["type"] == "text": + answer = answer[0]["text"] return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -49,6 +83,10 @@ class AnswerNode(BaseNode): } ) + def _is_variable(self, part, variable_keys): + cleaned_part = part.replace('{{', '').replace('}}', '') + return part.startswith('{{') and cleaned_part in variable_keys + @classmethod def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: """ diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index dfba9d0385..2da19bc409 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -6,7 +6,6 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool -from models.workflow import WorkflowNodeExecutionStatus class UserFrom(Enum): @@ -80,16 +79,9 @@ class BaseNode(ABC): :param variable_pool: variable pool :return: """ - try: - result = self._run( - variable_pool=variable_pool - ) - except Exception as e: - # process unhandled exception - result = NodeRunResult( - 
status=WorkflowNodeExecutionStatus.FAILED, - error=str(e) - ) + result = self._run( + variable_pool=variable_pool + ) self.node_run_result = result return result diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 2666ccc4f9..3241860c29 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -2,9 +2,9 @@ from typing import cast from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import ValueType, VariablePool +from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.end.entities import EndNodeData, EndNodeDataOutputs +from core.workflow.nodes.end.entities import EndNodeData from models.workflow import WorkflowNodeExecutionStatus @@ -20,34 +20,14 @@ class EndNode(BaseNode): """ node_data = self.node_data node_data = cast(self._node_data_cls, node_data) - outputs_config = node_data.outputs + output_variables = node_data.outputs - outputs = None - if outputs_config: - if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT: - plain_text_selector = outputs_config.plain_text_selector - if plain_text_selector: - outputs = { - 'text': variable_pool.get_variable_value( - variable_selector=plain_text_selector, - target_value_type=ValueType.STRING - ) - } - else: - outputs = { - 'text': '' - } - elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED: - structured_variables = outputs_config.structured_variables - if structured_variables: - outputs = {} - for variable_selector in structured_variables: - variable_value = variable_pool.get_variable_value( - variable_selector=variable_selector.value_selector - ) - outputs[variable_selector.variable] = variable_value - else: - outputs = {} + outputs = {} + for variable_selector in output_variables: + variable_value = variable_pool.get_variable_value( + variable_selector=variable_selector.value_selector + ) + outputs[variable_selector.variable] = variable_value return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, diff --git a/api/core/workflow/nodes/end/entities.py b/api/core/workflow/nodes/end/entities.py index 32212ae7fa..ad4fc8f04f 100644 --- a/api/core/workflow/nodes/end/entities.py +++ b/api/core/workflow/nodes/end/entities.py @@ -1,68 +1,9 @@ -from enum import Enum -from typing import Optional - -from pydantic import BaseModel - from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector -class EndNodeOutputType(Enum): - """ - END Node Output Types. - - none, plain-text, structured - """ - NONE = 'none' - PLAIN_TEXT = 'plain-text' - STRUCTURED = 'structured' - - @classmethod - def value_of(cls, value: str) -> 'OutputType': - """ - Get value of given output type. - - :param value: output type value - :return: output type - """ - for output_type in cls: - if output_type.value == value: - return output_type - raise ValueError(f'invalid output type value {value}') - - -class EndNodeDataOutputs(BaseModel): - """ - END Node Data Outputs. - """ - class OutputType(Enum): - """ - Output Types. - """ - NONE = 'none' - PLAIN_TEXT = 'plain-text' - STRUCTURED = 'structured' - - @classmethod - def value_of(cls, value: str) -> 'OutputType': - """ - Get value of given output type. 
- - :param value: output type value - :return: output type - """ - for output_type in cls: - if output_type.value == value: - return output_type - raise ValueError(f'invalid output type value {value}') - - type: OutputType = OutputType.NONE - plain_text_selector: Optional[list[str]] = None - structured_variables: Optional[list[VariableSelector]] = None - - class EndNodeData(BaseNodeData): """ END Node Data. """ - outputs: Optional[EndNodeDataOutputs] = None + outputs: list[VariableSelector] diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index ebc753537e..3109f9ea33 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,3 +1,4 @@ +import logging import time from typing import Optional @@ -41,6 +42,8 @@ node_classes = { NodeType.VARIABLE_ASSIGNER: VariableAssignerNode, } +logger = logging.getLogger(__name__) + class WorkflowEngineManager: def get_default_configs(self) -> list[dict]: @@ -407,6 +410,7 @@ class WorkflowEngineManager: variable_pool=workflow_run_state.variable_pool ) except Exception as e: + logger.exception(f"Node {node.node_data.title} run failed: {str(e)}") node_run_result = NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=str(e) diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 78f79e02fa..953c5c5a3c 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -531,10 +531,10 @@ class WorkflowConverter: "data": { "title": "END", "type": NodeType.END.value, - "outputs": { + "outputs": [{ "variable": "result", "value_selector": ["llm", "text"] - } + }] } } diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py new file mode 100644 index 0000000000..bad5d42a43 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -0,0 +1,56 @@ +from unittest.mock import MagicMock + +from core.workflow.entities.node_entities import SystemVariable +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.answer.answer_node import AnswerNode +from core.workflow.nodes.base_node import UserFrom +from core.workflow.nodes.if_else.if_else_node import IfElseNode +from extensions.ext_database import db +from models.workflow import WorkflowNodeExecutionStatus + + +def test_execute_answer(): + node = AnswerNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'answer', + 'data': { + 'title': '123', + 'type': 'answer', + 'variables': [ + { + 'value_selector': ['llm', 'text'], + 'variable': 'text' + }, + { + 'value_selector': ['start', 'weather'], + 'variable': 'weather' + }, + ], + 'answer': 'Today\'s weather is {{weather}}\n{{text}}\n{{img}}\nFin.' + } + } + ) + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.FILES: [], + }, user_inputs={}) + pool.append_variable(node_id='start', variable_key_list=['weather'], value='sunny') + pool.append_variable(node_id='llm', variable_key_list=['text'], value='You are a helpful AI.') + + # Mock db.session.close() + db.session.close = MagicMock() + + # execute node + result = node._run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs['answer'] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin." 
+ + +# TODO test files diff --git a/api/tests/unit_tests/core/workflow/nodes/if_else_node.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py similarity index 100% rename from api/tests/unit_tests/core/workflow/nodes/if_else_node.py rename to api/tests/unit_tests/core/workflow/nodes/test_if_else.py From 6633a92e1aef02aae56d6c0a1caa11aa3e7671fa Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 14 Mar 2024 11:35:51 +0800 Subject: [PATCH 187/200] fix: http --- .../workflow/nodes/http_request/entities.py | 2 +- .../nodes/http_request/http_executor.py | 6 +- .../nodes/http_request/http_request_node.py | 2 +- .../workflow/nodes/test_http.py | 74 +++++++++++++++++++ 4 files changed, 79 insertions(+), 5 deletions(-) diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index fbd4da3840..0683008954 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -33,7 +33,7 @@ class HttpRequestNodeData(BaseNodeData): return v class Body(BaseModel): - type: Literal[None, 'form-data', 'x-www-form-urlencoded', 'raw', 'json'] + type: Literal['none', 'form-data', 'x-www-form-urlencoded', 'raw', 'json'] data: Union[None, str] variables: list[VariableSelector] diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py index c96d5f07d1..3d307be0d1 100644 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -131,8 +131,6 @@ class HttpExecutor: self.headers['Content-Type'] = 'application/json' elif node_data.body.type == 'x-www-form-urlencoded': self.headers['Content-Type'] = 'application/x-www-form-urlencoded' - # elif node_data.body.type == 'form-data': - # self.headers['Content-Type'] = 'multipart/form-data' if node_data.body.type in ['form-data', 'x-www-form-urlencoded']: body = {} @@ -152,8 +150,10 @@ class HttpExecutor: } else: self.body = urlencode(body) - else: + elif node_data.body.type in ['json', 'raw']: self.body = original_body + elif node_data.body.type == 'none': + self.body = '' def _assembling_headers(self) -> dict[str, Any]: authorization = deepcopy(self.authorization) diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index c83e331fa8..a914ae13ff 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -42,7 +42,7 @@ class HttpRequestNode(BaseNode): inputs=variables, outputs={ 'status_code': response.status_code, - 'body': response, + 'body': response.body, 'headers': response.headers }, process_data={ diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 584e1d80a5..8b94105b44 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -84,6 +84,41 @@ def test_no_auth(setup_http_mock): assert '?A=b' in data assert 'X-Header: 123' in data +@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) +def test_custom_authorization_header(setup_http_mock): + node = HttpRequestNode(config={ + 'id': '1', + 'data': { + 'title': 'http', + 'desc': '', + 'variables': [{ + 'variable': 'args1', + 'value_selector': ['1', '123', 'args1'], + }], + 'method': 'get', + 'url': 'http://example.com', + 'authorization': { + 'type': 
'api-key',
+ 'config': {
+ 'type': 'custom',
+ 'api_key': 'Auth',
+ 'header': 'X-Auth',
+ },
+ },
+ 'headers': 'X-Header:123',
+ 'params': 'A:b',
+ 'body': None,
+ }
+ }, **BASIC_NODE_DATA)
+
+ result = node.run(pool)
+
+ data = result.process_data.get('request', '')
+
+ assert '?A=b' in data
+ assert 'X-Header: 123' in data
+ assert 'X-Auth: Auth' in data
+
@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) def test_template(setup_http_mock): node = HttpRequestNode(config={
@@ -237,3 +272,42 @@ def test_form_data(setup_http_mock): assert '2' in data assert 'api-key: Basic ak-xxx' in data assert 'X-Header: 123' in data
+
+def test_none_data(setup_http_mock):
+ node = HttpRequestNode(config={
+ 'id': '1',
+ 'data': {
+ 'title': 'http',
+ 'desc': '',
+ 'variables': [{
+ 'variable': 'args1',
+ 'value_selector': ['1', '123', 'args1'],
+ }, {
+ 'variable': 'args2',
+ 'value_selector': ['1', '123', 'args2'],
+ }],
+ 'method': 'post',
+ 'url': 'http://example.com',
+ 'authorization': {
+ 'type': 'api-key',
+ 'config': {
+ 'type': 'basic',
+ 'api_key':'ak-xxx',
+ 'header': 'api-key',
+ }
+ },
+ 'headers': 'X-Header:123',
+ 'params': 'A:b',
+ 'body': {
+ 'type': 'none',
+ 'data': '123123123'
+ },
+ }
+ }, **BASIC_NODE_DATA)
+
+ result = node.run(pool)
+ data = result.process_data.get('request', '')
+
+ assert 'api-key: Basic ak-xxx' in data
+ assert 'X-Header: 123' in data
+ assert '123123123' not in data \ No newline at end of file
From fb6e5bf4d5f40165ef41c5e850ae50467300aec7 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 14 Mar 2024 11:39:05 +0800 Subject: [PATCH 188/200] fix publish route

--- api/controllers/console/app/workflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 6f81da5691..d5967dd5ed 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py
@@ -299,7 +299,7 @@
api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/run')
api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run')
api.add_resource(WorkflowTaskStopApi, '/apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop')
api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run')
-api.add_resource(PublishedWorkflowApi, '/apps/<uuid:app_id>/workflows/published')
+api.add_resource(PublishedWorkflowApi, '/apps/<uuid:app_id>/workflows/publish')
api.add_resource(DefaultBlockConfigsApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs')
api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs' '/<string:block_type>')

From c2ded79cb2bd752866b42ca8b0a9640da1be9e66 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 14 Mar 2024 11:58:56 +0800 Subject: [PATCH 189/200] fix: node type

--- api/core/workflow/nodes/tool/tool_node.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index c62e025e75..89c8389085 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py
@@ -136,12 +136,12 @@ class ToolNode(BaseNode):

@classmethod
- def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+ def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]:
"""
Extract variable selector to variable mapping
"""
return {
k.variable: k.value_selector
- for k in cast(ToolNodeData, node_data).tool_parameters
+ for k in node_data.tool_parameters
if
k.variable_type == 'selector' } From 87a36a1fc8ba3c88646d884761a5e19b108fcefb Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 14 Mar 2024 11:59:33 +0800 Subject: [PATCH 190/200] fix: linter --- api/core/workflow/nodes/tool/tool_node.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 89c8389085..b03ad45e6c 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -5,7 +5,6 @@ from core.file.file_obj import FileTransferMethod from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode From 72d2f76d2444a14ec34f6ba1dbf8d098241a96d3 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 14 Mar 2024 12:12:26 +0800 Subject: [PATCH 191/200] fix default configs --- api/core/workflow/workflow_engine_manager.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 3109f9ea33..a7379e6e99 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -54,10 +54,7 @@ class WorkflowEngineManager: for node_type, node_class in node_classes.items(): default_config = node_class.get_default_config() if default_config: - default_block_configs.append({ - 'type': node_type.value, - 'config': default_config - }) + default_block_configs.append(default_config) return default_block_configs From 737321da756dc16cd882f01204129bc0febc567b Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 14 Mar 2024 12:17:15 +0800 Subject: [PATCH 192/200] add advanced chat apis support --- api/controllers/console/app/audio.py | 2 +- api/controllers/console/app/conversation.py | 8 ++++---- api/controllers/console/app/message.py | 4 ++-- api/controllers/console/app/statistic.py | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 4de4a6f3fe..29d89ae460 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -37,7 +37,7 @@ class ChatMessageAudioApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) def post(self, app_model): file = request.files['file'] diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 33711076f8..11dece3a9e 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -112,7 +112,7 @@ class CompletionConversationDetailApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) def delete(self, app_model, conversation_id): conversation_id = str(conversation_id) @@ -133,7 +133,7 @@ class ChatConversationApi(Resource): @setup_required @login_required @account_initialization_required - 
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @marshal_with(conversation_with_summary_pagination_fields) def get(self, app_model): parser = reqparse.RequestParser() @@ -218,7 +218,7 @@ class ChatConversationDetailApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @marshal_with(conversation_detail_fields) def get(self, app_model, conversation_id): conversation_id = str(conversation_id) @@ -227,7 +227,7 @@ class ChatConversationDetailApi(Resource): @setup_required @login_required - @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @account_initialization_required def delete(self, app_model, conversation_id): conversation_id = str(conversation_id) diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 111ec7d787..56d2e718e7 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -42,7 +42,7 @@ class ChatMessageListApi(Resource): @setup_required @login_required - @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @account_initialization_required @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model): @@ -194,7 +194,7 @@ class MessageSuggestedQuestionApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) def get(self, app_model, message_id): message_id = str(message_id) diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 51fe53c0ec..d687b52dc8 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -203,7 +203,7 @@ class AverageSessionInteractionStatistic(Resource): @setup_required @login_required @account_initialization_required - @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) + @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) def get(self, app_model): account = current_user From 6e51ce123c66feb738eebbe8740e2ebb509612a2 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 14 Mar 2024 12:56:25 +0800 Subject: [PATCH 193/200] fix: null conversation id --- ...nable_tool_file_without_conversation_id.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py diff --git a/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py new file mode 100644 index 0000000000..d91288bcf5 --- /dev/null +++ b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py @@ -0,0 +1,36 @@ +"""enable tool file without conversation id + +Revision ID: 563cf8bf777b +Revises: b5429b71023c +Create Date: 2024-03-14 04:54:56.679506 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = '563cf8bf777b' +down_revision = 'b5429b71023c' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.alter_column('conversation_id', + existing_type=postgresql.UUID(), + nullable=True) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.alter_column('conversation_id', + existing_type=postgresql.UUID(), + nullable=False) + + # ### end Alembic commands ### From 74e644be1ca1436b1c7ef265158fceab761662f4 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 14 Mar 2024 12:56:57 +0800 Subject: [PATCH 194/200] fix: linter --- .../563cf8bf777b_enable_tool_file_without_conversation_id.py | 1 - api/models/tools.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py index d91288bcf5..299f442de9 100644 --- a/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py +++ b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py @@ -6,7 +6,6 @@ Create Date: 2024-03-14 04:54:56.679506 """ from alembic import op -import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. diff --git a/api/models/tools.py b/api/models/tools.py index bceef7a829..4bdf2503ce 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -218,7 +218,7 @@ class ToolFile(db.Model): # tenant id tenant_id = db.Column(UUID, nullable=False) # conversation id - conversation_id = db.Column(UUID, nullable=False) + conversation_id = db.Column(UUID, nullable=True) # file key file_key = db.Column(db.String(255), nullable=False) # mime type From dc53362506f8453237030b560cf2a8d884f8290b Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 14 Mar 2024 13:24:48 +0800 Subject: [PATCH 195/200] fix: conversation_id equals to none --- api/core/workflow/nodes/tool/tool_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index b03ad45e6c..ca217182cc 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -78,7 +78,7 @@ class ToolNode(BaseNode): messages=messages, user_id=self.user_id, tenant_id=self.tenant_id, - conversation_id='', + conversation_id=None, ) # extract plain text and files files = self._extract_tool_response_binary(messages) From ede65eca4d9c14484ea1b4674febae2b1ddb20c9 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 14 Mar 2024 16:38:22 +0800 Subject: [PATCH 196/200] fix: tool --- api/core/workflow/nodes/tool/entities.py | 11 +++++++++-- api/core/workflow/nodes/tool/tool_node.py | 3 ++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 0b3bf76aac..7eb3cf655b 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -3,7 +3,6 @@ from typing import Literal, Optional, Union from pydantic import BaseModel, validator from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.variable_entities import VariableSelector 
ToolParameterValue = Union[str, int, float, bool] @@ -16,8 +15,10 @@ class ToolEntity(BaseModel): tool_configurations: dict[str, ToolParameterValue] class ToolNodeData(BaseNodeData, ToolEntity): - class ToolInput(VariableSelector): + class ToolInput(BaseModel): + variable: str variable_type: Literal['selector', 'static'] + value_selector: Optional[list[str]] value: Optional[str] @validator('value') @@ -25,6 +26,12 @@ class ToolNodeData(BaseNodeData, ToolEntity): if values['variable_type'] == 'static' and value is None: raise ValueError('value is required for static variable') return value + + @validator('value_selector') + def check_value_selector(cls, value_selector, values, **kwargs): + if values['variable_type'] == 'selector' and value_selector is None: + raise ValueError('value_selector is required for selector variable') + return value_selector """ Tool Node Schema diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index ca217182cc..d0bfd9e797 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -44,7 +44,7 @@ class ToolNode(BaseNode): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters, - error=f'Failed to invoke tool: {str(e)}' + error=f'Failed to invoke tool: {str(e)}', ) # convert tool messages @@ -56,6 +56,7 @@ class ToolNode(BaseNode): 'text': plain_text, 'files': files }, + inputs=parameters ) def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData) -> dict: From 1cfeb989f77fde324866ef9897268cabbcb3c747 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 14 Mar 2024 19:17:27 +0800 Subject: [PATCH 197/200] fix: code default output --- api/core/workflow/nodes/code/code_node.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 5dfe398711..0b46f86e9d 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -53,12 +53,12 @@ class CodeNode(BaseNode): ], "code_language": "javascript", "code": JAVASCRIPT_DEFAULT_CODE, - "outputs": [ - { - "variable": "result", - "variable_type": "number" + "outputs": { + "result": { + "type": "number", + "children": None } - ] + } } } @@ -77,12 +77,12 @@ class CodeNode(BaseNode): ], "code_language": "python3", "code": PYTHON_DEFAULT_CODE, - "outputs": [ - { - "variable": "result", - "variable_type": "number" + "outputs": { + "result": { + "type": "number", + "children": None } - ] + } } } From 12eb2363646b316999a766633a7cee723e623e06 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 14 Mar 2024 20:49:53 +0800 Subject: [PATCH 198/200] answer stream output support --- .../advanced_chat/generate_task_pipeline.py | 277 +++++++++++++++++- .../workflow_event_trigger_callback.py | 39 +-- .../apps/message_based_app_queue_manager.py | 6 +- .../workflow_event_trigger_callback.py | 2 +- api/core/app/entities/queue_entities.py | 11 +- .../callbacks/base_workflow_callback.py | 2 +- api/core/workflow/nodes/answer/answer_node.py | 129 +++++--- api/core/workflow/nodes/answer/entities.py | 26 ++ api/core/workflow/nodes/base_node.py | 9 +- api/core/workflow/nodes/llm/llm_node.py | 2 +- 10 files changed, 413 insertions(+), 90 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index e8463e59d3..ca4b143027 100644 --- 
a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -2,7 +2,7 @@ import json import logging import time from collections.abc import Generator -from typing import Optional, Union +from typing import Optional, Union, cast from pydantic import BaseModel, Extra @@ -13,6 +13,7 @@ from core.app.entities.app_invoke_entities import ( InvokeFrom, ) from core.app.entities.queue_entities import ( + QueueAdvancedChatMessageEndEvent, QueueAnnotationReplyEvent, QueueErrorEvent, QueueMessageFileEvent, @@ -34,6 +35,8 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeErr from core.moderation.output_moderation import ModerationRule, OutputModeration from core.tools.tool_file_manager import ToolFileManager from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable +from core.workflow.nodes.answer.answer_node import AnswerNode +from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk from events.message_event import message_was_created from extensions.ext_database import db from models.account import Account @@ -51,15 +54,26 @@ from services.annotation_service import AppAnnotationService logger = logging.getLogger(__name__) +class StreamGenerateRoute(BaseModel): + """ + StreamGenerateRoute entity + """ + answer_node_id: str + generate_route: list[GenerateRouteChunk] + current_route_position: int = 0 + + class TaskState(BaseModel): """ TaskState entity """ + class NodeExecutionInfo(BaseModel): """ NodeExecutionInfo entity """ workflow_node_execution_id: str + node_type: NodeType start_at: float class Config: @@ -77,9 +91,11 @@ class TaskState(BaseModel): total_tokens: int = 0 total_steps: int = 0 - running_node_execution_infos: dict[str, NodeExecutionInfo] = {} + ran_node_execution_infos: dict[str, NodeExecutionInfo] = {} latest_node_execution_info: Optional[NodeExecutionInfo] = None + current_stream_generate_state: Optional[StreamGenerateRoute] = None + class Config: """Configuration for this pydantic object.""" @@ -122,6 +138,11 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): self._output_moderation_handler = self._init_output_moderation() self._stream = stream + if stream: + self._stream_generate_routes = self._get_stream_generate_routes() + else: + self._stream_generate_routes = None + def process(self) -> Union[dict, Generator]: """ Process generate task pipeline. 
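The core mechanism of this commit: each Answer node's template is pre-parsed into an ordered "generate route" of static text chunks and variable-reference chunks, so static text can be streamed to the client immediately while variable chunks wait for the upstream node that produces them. A minimal sketch of that split, assuming the `{{variable}}` template syntax used throughout the series (the helper name and tuple encoding below are illustrative assumptions, not the patch's own API):

```python
import re

def split_generate_route(template: str) -> list[tuple[str, str]]:
    """Split an answer template into ('text', ...) and ('var', ...) chunks."""
    chunks: list[tuple[str, str]] = []
    pos = 0
    for match in re.finditer(r'\{\{(\w+)\}\}', template):
        if match.start() > pos:
            # literal text between placeholders can be streamed right away
            chunks.append(('text', template[pos:match.start()]))
        # a variable reference has to wait for the producing node's output
        chunks.append(('var', match.group(1)))
        pos = match.end()
    if pos < len(template):
        chunks.append(('text', template[pos:]))
    return chunks

assert split_generate_route('Weather: {{weather}}\n{{text}} Fin.') == [
    ('text', 'Weather: '),
    ('var', 'weather'),
    ('text', '\n'),
    ('var', 'text'),
    ('text', ' Fin.'),
]
```

Parsing the route once up front is what lets the pipeline resume mid-template whenever a node finishes, which is the role `current_route_position` plays in the `StreamGenerateRoute` model introduced below.
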
@@ -290,6 +311,11 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(data) break + self._queue_manager.publish( + QueueAdvancedChatMessageEndEvent(), + PublishFrom.TASK_PIPELINE + ) + workflow_run_response = { 'event': 'workflow_finished', 'task_id': self._application_generate_entity.task_id, @@ -309,7 +335,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): } yield self._yield_response(workflow_run_response) - + elif isinstance(event, QueueAdvancedChatMessageEndEvent): # response moderation if self._output_moderation_handler: self._output_moderation_handler.stop_thread() @@ -390,6 +416,11 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): yield self._yield_response(response) elif isinstance(event, QueueTextChunkEvent): + if not self._is_stream_out_support( + event=event + ): + continue + delta_text = event.text if delta_text is None: continue @@ -467,20 +498,28 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): latest_node_execution_info = TaskState.NodeExecutionInfo( workflow_node_execution_id=workflow_node_execution.id, + node_type=event.node_type, start_at=time.perf_counter() ) - self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info + self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info self._task_state.latest_node_execution_info = latest_node_execution_info self._task_state.total_steps += 1 db.session.close() + # search stream_generate_routes if node id is answer start at node + if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes: + self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id] + + # stream outputs from start + self._generate_stream_outputs_when_node_start() + return workflow_node_execution def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution: - current_node_execution = self._task_state.running_node_execution_infos[event.node_id] + current_node_execution = self._task_state.ran_node_execution_infos[event.node_id] workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first() if isinstance(event, QueueNodeSucceededEvent): @@ -508,8 +547,8 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): error=event.error ) - # remove running node execution info - del self._task_state.running_node_execution_infos[event.node_id] + # stream outputs when node finished + self._generate_stream_outputs_when_node_finished() db.session.close() @@ -517,7 +556,8 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \ -> WorkflowRun: - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() + workflow_run = (db.session.query(WorkflowRun) + .filter(WorkflowRun.id == self._task_state.workflow_run_id).first()) if isinstance(event, QueueStopEvent): workflow_run = self._workflow_run_failed( workflow_run=workflow_run, @@ -642,7 +682,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): QuotaExceededError: { 'code': 'provider_quota_exceeded', 'message': "Your quota for Dify Hosted Model Provider has been 
exhausted. " - "Please go to Settings -> Model Provider to complete your own provider credentials.", + "Please go to Settings -> Model Provider to complete your own provider credentials.", 'status': 400 }, ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400}, @@ -660,10 +700,10 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): else: logging.error(e) data = { - 'code': 'internal_server_error', + 'code': 'internal_server_error', 'message': 'Internal Server Error, please contact support.', 'status': 500 - } + } return { 'event': 'error', @@ -730,3 +770,218 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline): ), queue_manager=self._queue_manager ) + + def _get_stream_generate_routes(self) -> dict[str, StreamGenerateRoute]: + """ + Get stream generate routes. + :return: + """ + # find all answer nodes + graph = self._workflow.graph_dict + answer_node_configs = [ + node for node in graph['nodes'] + if node.get('data', {}).get('type') == NodeType.ANSWER.value + ] + + # parse stream output node value selectors of answer nodes + stream_generate_routes = {} + for node_config in answer_node_configs: + # get generate route for stream output + answer_node_id = node_config['id'] + generate_route = AnswerNode.extract_generate_route_selectors(node_config) + start_node_id = self._get_answer_start_at_node_id(graph, answer_node_id) + if not start_node_id: + continue + + stream_generate_routes[start_node_id] = StreamGenerateRoute( + answer_node_id=answer_node_id, + generate_route=generate_route + ) + + return stream_generate_routes + + def _get_answer_start_at_node_id(self, graph: dict, target_node_id: str) \ + -> Optional[str]: + """ + Get answer start at node id. + :param graph: graph + :param target_node_id: target node ID + :return: + """ + nodes = graph.get('nodes') + edges = graph.get('edges') + + # fetch all ingoing edges from source node + ingoing_edge = None + for edge in edges: + if edge.get('target') == target_node_id: + ingoing_edge = edge + break + + if not ingoing_edge: + return None + + source_node_id = ingoing_edge.get('source') + source_node = next((node for node in nodes if node.get('id') == source_node_id), None) + if not source_node: + return None + + node_type = source_node.get('data', {}).get('type') + if node_type in [ + NodeType.ANSWER.value, + NodeType.IF_ELSE.value, + NodeType.QUESTION_CLASSIFIER + ]: + start_node_id = target_node_id + elif node_type == NodeType.START.value: + start_node_id = source_node_id + else: + start_node_id = self._get_answer_start_at_node_id(graph, source_node_id) + + return start_node_id + + def _generate_stream_outputs_when_node_start(self) -> None: + """ + Generate stream outputs. 
+ :return: + """ + if not self._task_state.current_stream_generate_state: + return + + for route_chunk in self._task_state.current_stream_generate_state.generate_route: + if route_chunk.type == 'text': + route_chunk = cast(TextGenerateRouteChunk, route_chunk) + for token in route_chunk.text: + self._queue_manager.publish( + QueueTextChunkEvent( + text=token + ), PublishFrom.TASK_PIPELINE + ) + time.sleep(0.01) + + self._task_state.current_stream_generate_state.current_route_position += 1 + else: + break + + # all route chunks are generated + if self._task_state.current_stream_generate_state.current_route_position == len( + self._task_state.current_stream_generate_state.generate_route): + self._task_state.current_stream_generate_state = None + + def _generate_stream_outputs_when_node_finished(self) -> None: + """ + Generate stream outputs. + :return: + """ + if not self._task_state.current_stream_generate_state: + return + + route_chunks = self._task_state.current_stream_generate_state.generate_route[ + self._task_state.current_stream_generate_state.current_route_position:] + + for route_chunk in route_chunks: + if route_chunk.type == 'text': + route_chunk = cast(TextGenerateRouteChunk, route_chunk) + for token in route_chunk.text: + self._queue_manager.publish( + QueueTextChunkEvent( + text=token + ), PublishFrom.TASK_PIPELINE + ) + time.sleep(0.01) + else: + route_chunk = cast(VarGenerateRouteChunk, route_chunk) + value_selector = route_chunk.value_selector + route_chunk_node_id = value_selector[0] + + # check chunk node id is before current node id or equal to current node id + if route_chunk_node_id not in self._task_state.ran_node_execution_infos: + break + + latest_node_execution_info = self._task_state.latest_node_execution_info + + # get route chunk node execution info + route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id] + if (route_chunk_node_execution_info.node_type == NodeType.LLM + and latest_node_execution_info.node_type == NodeType.LLM): + # only LLM support chunk stream output + self._task_state.current_stream_generate_state.current_route_position += 1 + continue + + # get route chunk node execution + route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id).first() + + outputs = route_chunk_node_execution.outputs_dict + + # get value from outputs + value = None + for key in value_selector[1:]: + if not value: + value = outputs.get(key) + else: + value = value.get(key) + + if value: + text = None + if isinstance(value, str | int | float): + text = str(value) + elif isinstance(value, object): # TODO FILE + # convert file to markdown + text = f'![]({value.get("url")})' + pass + + if text: + for token in text: + self._queue_manager.publish( + QueueTextChunkEvent( + text=token + ), PublishFrom.TASK_PIPELINE + ) + time.sleep(0.01) + + self._task_state.current_stream_generate_state.current_route_position += 1 + + # all route chunks are generated + if self._task_state.current_stream_generate_state.current_route_position == len( + self._task_state.current_stream_generate_state.generate_route): + self._task_state.current_stream_generate_state = None + + def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool: + """ + Is stream out support + :param event: queue text chunk event + :return: + """ + if not event.metadata: + return True + + if 'node_id' not in event.metadata: + return True + + node_type = 
event.metadata.get('node_type') + stream_output_value_selector = event.metadata.get('value_selector') + if not stream_output_value_selector: + return False + + if not self._task_state.current_stream_generate_state: + return False + + route_chunk = self._task_state.current_stream_generate_state.generate_route[ + self._task_state.current_stream_generate_state.current_route_position] + + if route_chunk.type != 'var': + return False + + if node_type != NodeType.LLM: + # only LLM support chunk stream output + return False + + route_chunk = cast(VarGenerateRouteChunk, route_chunk) + value_selector = route_chunk.value_selector + + # check chunk node id is before current node id or equal to current node id + if value_selector != stream_output_value_selector: + return False + + return True diff --git a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py index b4a6a9602f..972fda2d49 100644 --- a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py +++ b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py @@ -20,7 +20,6 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): self._queue_manager = queue_manager - self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict) def on_workflow_run_started(self) -> None: """ @@ -114,34 +113,16 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): PublishFrom.APPLICATION_MANAGER ) - def on_node_text_chunk(self, node_id: str, text: str) -> None: + def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: """ Publish text chunk """ - if node_id in self._streamable_node_ids: - self._queue_manager.publish( - QueueTextChunkEvent( - text=text - ), PublishFrom.APPLICATION_MANAGER - ) - - def _fetch_streamable_node_ids(self, graph: dict) -> list[str]: - """ - Fetch streamable node ids - When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output - When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output - - :param graph: workflow graph - :return: - """ - streamable_node_ids = [] - end_node_ids = [] - for node_config in graph.get('nodes'): - if node_config.get('data', {}).get('type') == NodeType.END.value: - end_node_ids.append(node_config.get('id')) - - for edge_config in graph.get('edges'): - if edge_config.get('target') in end_node_ids: - streamable_node_ids.append(edge_config.get('source')) - - return streamable_node_ids + self._queue_manager.publish( + QueueTextChunkEvent( + text=text, + metadata={ + "node_id": node_id, + **metadata + } + ), PublishFrom.APPLICATION_MANAGER + ) diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py index 6d0a71f495..f4ff44ddda 100644 --- a/api/core/app/apps/message_based_app_queue_manager.py +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -3,12 +3,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, MessageQueueMessage, + QueueAdvancedChatMessageEndEvent, QueueErrorEvent, QueueMessage, QueueMessageEndEvent, QueueStopEvent, - QueueWorkflowFailedEvent, - QueueWorkflowSucceededEvent, ) @@ -54,8 +53,7 @@ class MessageBasedAppQueueManager(AppQueueManager): if isinstance(event, QueueStopEvent | QueueErrorEvent | 
QueueMessageEndEvent - | QueueWorkflowSucceededEvent - | QueueWorkflowFailedEvent): + | QueueAdvancedChatMessageEndEvent): self.stop_listen() if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py index 59ef44cd2e..e5a8e8d374 100644 --- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py +++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py @@ -112,7 +112,7 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): PublishFrom.APPLICATION_MANAGER ) - def on_node_text_chunk(self, node_id: str, text: str) -> None: + def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: """ Publish text chunk """ diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 153607e1b4..5c31996fd3 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -17,6 +17,7 @@ class QueueEvent(Enum): AGENT_MESSAGE = "agent_message" MESSAGE_REPLACE = "message_replace" MESSAGE_END = "message_end" + ADVANCED_CHAT_MESSAGE_END = "advanced_chat_message_end" WORKFLOW_STARTED = "workflow_started" WORKFLOW_SUCCEEDED = "workflow_succeeded" WORKFLOW_FAILED = "workflow_failed" @@ -53,6 +54,7 @@ class QueueTextChunkEvent(AppQueueEvent): """ event = QueueEvent.TEXT_CHUNK text: str + metadata: Optional[dict] = None class QueueAgentMessageEvent(AppQueueEvent): @@ -92,7 +94,14 @@ class QueueMessageEndEvent(AppQueueEvent): QueueMessageEndEvent entity """ event = QueueEvent.MESSAGE_END - llm_result: LLMResult + llm_result: Optional[LLMResult] = None + + +class QueueAdvancedChatMessageEndEvent(AppQueueEvent): + """ + QueueAdvancedChatMessageEndEvent entity + """ + event = QueueEvent.ADVANCED_CHAT_MESSAGE_END class QueueWorkflowStartedEvent(AppQueueEvent): diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py index 9594fa2037..1f5472b430 100644 --- a/api/core/workflow/callbacks/base_workflow_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -64,7 +64,7 @@ class BaseWorkflowCallback(ABC): raise NotImplementedError @abstractmethod - def on_node_text_chunk(self, node_id: str, text: str) -> None: + def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: """ Publish text chunk """ diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 97ddafad01..d8ff5cb6f6 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -4,7 +4,12 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import ValueType, VariablePool -from core.workflow.nodes.answer.entities import AnswerNodeData +from core.workflow.nodes.answer.entities import ( + AnswerNodeData, + GenerateRouteChunk, + TextGenerateRouteChunk, + VarGenerateRouteChunk, +) from core.workflow.nodes.base_node import BaseNode from models.workflow import WorkflowNodeExecutionStatus @@ -22,6 +27,40 @@ class AnswerNode(BaseNode): node_data = self.node_data node_data = cast(self._node_data_cls, node_data) + # generate routes + generate_routes = 
self.extract_generate_route_from_node_data(node_data) + + answer = [] + for part in generate_routes: + if part.type == "var": + part = cast(VarGenerateRouteChunk, part) + value_selector = part.value_selector + value = variable_pool.get_variable_value( + variable_selector=value_selector, + target_value_type=ValueType.STRING + ) + + answer_part = { + "type": "text", + "text": value + } + # TODO File + else: + part = cast(TextGenerateRouteChunk, part) + answer_part = { + "type": "text", + "text": part.text + } + + if len(answer) > 0 and answer[-1]["type"] == "text" and answer_part["type"] == "text": + answer[-1]["text"] += answer_part["text"] + else: + answer.append(answer_part) + + if len(answer) == 1 and answer[0]["type"] == "text": + answer = answer[0]["text"] + + # re-fetch variable values variable_values = {} for variable_selector in node_data.variables: value = variable_pool.get_variable_value( @@ -31,7 +70,39 @@ class AnswerNode(BaseNode): variable_values[variable_selector.variable] = value - variable_keys = list(variable_values.keys()) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variable_values, + outputs={ + "answer": answer + } + ) + + @classmethod + def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]: + """ + Extract generate route selectors + :param config: node config + :return: + """ + node_data = cls._node_data_cls(**config.get("data", {})) + node_data = cast(cls._node_data_cls, node_data) + + return cls.extract_generate_route_from_node_data(node_data) + + @classmethod + def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]: + """ + Extract generate route from node data + :param node_data: node data object + :return: + """ + value_selector_mapping = { + variable_selector.variable: variable_selector.value_selector + for variable_selector in node_data.variables + } + + variable_keys = list(value_selector_mapping.keys()) # format answer template template_parser = PromptTemplateParser(node_data.answer) @@ -44,46 +115,24 @@ class AnswerNode(BaseNode): for var in variable_keys: template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω') - split_template = [ - { - "type": "var" if self._is_variable(part, variable_keys) else "text", - "value": part.replace('Ω', '') if self._is_variable(part, variable_keys) else part - } - for part in template.split('Ω') if part - ] + generate_routes = [] + for part in template.split('Ω'): + if part: + if cls._is_variable(part, variable_keys): + var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '') + value_selector = value_selector_mapping[var_key] + generate_routes.append(VarGenerateRouteChunk( + value_selector=value_selector + )) + else: + generate_routes.append(TextGenerateRouteChunk( + text=part + )) - answer = [] - for part in split_template: - if part["type"] == "var": - value = variable_values.get(part["value"].replace('{{', '').replace('}}', '')) - answer_part = { - "type": "text", - "text": value - } - # TODO File - else: - answer_part = { - "type": "text", - "text": part["value"] - } + return generate_routes - if len(answer) > 0 and answer[-1]["type"] == "text" and answer_part["type"] == "text": - answer[-1]["text"] += answer_part["text"] - else: - answer.append(answer_part) - - if len(answer) == 1 and answer[0]["type"] == "text": - answer = answer[0]["text"] - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variable_values, - outputs={ - "answer": answer - } - ) - - def 
_is_variable(self, part, variable_keys): + @classmethod + def _is_variable(cls, part, variable_keys): cleaned_part = part.replace('{{', '').replace('}}', '') return part.startswith('{{') and cleaned_part in variable_keys diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py index 7c6fed3e4e..8aed752ccb 100644 --- a/api/core/workflow/nodes/answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -1,3 +1,6 @@ + +from pydantic import BaseModel + from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector @@ -8,3 +11,26 @@ class AnswerNodeData(BaseNodeData): """ variables: list[VariableSelector] = [] answer: str + + +class GenerateRouteChunk(BaseModel): + """ + Generate Route Chunk. + """ + type: str + + +class VarGenerateRouteChunk(GenerateRouteChunk): + """ + Var Generate Route Chunk. + """ + type: str = "var" + value_selector: list[str] + + +class TextGenerateRouteChunk(GenerateRouteChunk): + """ + Text Generate Route Chunk. + """ + type: str = "text" + text: str diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 2da19bc409..7cc9c6ee3d 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -86,17 +86,22 @@ class BaseNode(ABC): self.node_run_result = result return result - def publish_text_chunk(self, text: str) -> None: + def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None: """ Publish text chunk :param text: chunk text + :param value_selector: value selector :return: """ if self.callbacks: for callback in self.callbacks: callback.on_node_text_chunk( node_id=self.node_id, - text=text + text=text, + metadata={ + "node_type": self.node_type, + "value_selector": value_selector + } ) @classmethod diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 9285bbe74e..cb5a333091 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -169,7 +169,7 @@ class LLMNode(BaseNode): text = result.delta.message.content full_text += text - self.publish_text_chunk(text=text) + self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text']) if not model: model = result.model From 785dfc5c0085cc93e64480d5f03b7eaf00c1b57c Mon Sep 17 00:00:00 2001 From: jyong Date: Fri, 15 Mar 2024 14:40:53 +0800 Subject: [PATCH 199/200] dataset retrival --- .../dataset_multi_retriever_tool.py | 194 ++++++++++ .../dataset_retriever_tool.py | 159 ++++++++ .../nodes/knowledge_retrieval/entities.py | 52 +++ .../knowledge_retrieval.py | 0 .../knowledge_retrieval_node.py | 364 +++++++++++++++++- 5 files changed, 766 insertions(+), 3 deletions(-) create mode 100644 api/core/workflow/nodes/knowledge_retrieval/dataset_multi_retriever_tool.py create mode 100644 api/core/workflow/nodes/knowledge_retrieval/dataset_retriever_tool.py create mode 100644 api/core/workflow/nodes/knowledge_retrieval/entities.py create mode 100644 api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval.py diff --git a/api/core/workflow/nodes/knowledge_retrieval/dataset_multi_retriever_tool.py b/api/core/workflow/nodes/knowledge_retrieval/dataset_multi_retriever_tool.py new file mode 100644 index 0000000000..d9934acff9 --- /dev/null +++ b/api/core/workflow/nodes/knowledge_retrieval/dataset_multi_retriever_tool.py @@ -0,0 +1,194 @@ +import threading +from typing import Optional + +from flask import Flask, 
current_app +from langchain.tools import BaseTool +from pydantic import BaseModel, Field + +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.rag.datasource.retrieval_service import RetrievalService +from core.rerank.rerank import RerankRunner +from extensions.ext_database import db +from models.dataset import Dataset, Document, DocumentSegment + +default_retrieval_model = { + 'search_method': 'semantic_search', + 'reranking_enable': False, + 'reranking_model': { + 'reranking_provider_name': '', + 'reranking_model_name': '' + }, + 'top_k': 2, + 'score_threshold_enabled': False +} + + +class DatasetMultiRetrieverToolInput(BaseModel): + query: str = Field(..., description="dataset multi retriever and rerank") + + +class DatasetMultiRetrieverTool(BaseTool): + """Tool for querying multi dataset.""" + name: str = "dataset-" + args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput + description: str = "dataset multi retriever and rerank. " + tenant_id: str + dataset_ids: list[str] + top_k: int = 2 + score_threshold: Optional[float] = None + reranking_provider_name: str + reranking_model_name: str + return_resource: bool + retriever_from: str + hit_callbacks: list[DatasetIndexToolCallbackHandler] = [] + + @classmethod + def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs): + return cls( + name=f'dataset-{tenant_id}', + tenant_id=tenant_id, + dataset_ids=dataset_ids, + **kwargs + ) + + def _run(self, query: str) -> str: + threads = [] + all_documents = [] + for dataset_id in self.dataset_ids: + retrieval_thread = threading.Thread(target=self._retriever, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset_id': dataset_id, + 'query': query, + 'all_documents': all_documents, + 'hit_callbacks': self.hit_callbacks + }) + threads.append(retrieval_thread) + retrieval_thread.start() + for thread in threads: + thread.join() + # do rerank for searched documents + model_manager = ModelManager() + rerank_model_instance = model_manager.get_model_instance( + tenant_id=self.tenant_id, + provider=self.reranking_provider_name, + model_type=ModelType.RERANK, + model=self.reranking_model_name + ) + + rerank_runner = RerankRunner(rerank_model_instance) + all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k) + + for hit_callback in self.hit_callbacks: + hit_callback.on_tool_end(all_documents) + + document_score_list = {} + for item in all_documents: + if 'score' in item.metadata and item.metadata['score']: + document_score_list[item.metadata['doc_id']] = item.metadata['score'] + + document_context_list = [] + index_node_ids = [document.metadata['doc_id'] for document in all_documents] + segments = DocumentSegment.query.filter( + DocumentSegment.dataset_id.in_(self.dataset_ids), + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == 'completed', + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids) + ).all() + + if segments: + index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} + sorted_segments = sorted(segments, + key=lambda segment: index_node_id_to_position.get(segment.index_node_id, + float('inf'))) + for segment in sorted_segments: + if segment.answer: + document_context_list.append(f'question:{segment.content} answer:{segment.answer}') + else: + 
document_context_list.append(segment.content) + if self.return_resource: + context_list = [] + resource_number = 1 + for segment in sorted_segments: + dataset = Dataset.query.filter_by( + id=segment.dataset_id + ).first() + document = Document.query.filter(Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, + ).first() + if dataset and document: + source = { + 'position': resource_number, + 'dataset_id': dataset.id, + 'dataset_name': dataset.name, + 'document_id': document.id, + 'document_name': document.name, + 'data_source_type': document.data_source_type, + 'segment_id': segment.id, + 'retriever_from': self.retriever_from, + 'score': document_score_list.get(segment.index_node_id, None) + } + + if self.retriever_from == 'dev': + source['hit_count'] = segment.hit_count + source['word_count'] = segment.word_count + source['segment_position'] = segment.position + source['index_node_hash'] = segment.index_node_hash + if segment.answer: + source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' + else: + source['content'] = segment.content + context_list.append(source) + resource_number += 1 + + for hit_callback in self.hit_callbacks: + hit_callback.return_retriever_resource_info(context_list) + + return str("\n".join(document_context_list)) + + async def _arun(self, tool_input: str) -> str: + raise NotImplementedError() + + def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: list, + hit_callbacks: list[DatasetIndexToolCallbackHandler]): + with flask_app.app_context(): + dataset = db.session.query(Dataset).filter( + Dataset.tenant_id == self.tenant_id, + Dataset.id == dataset_id + ).first() + + if not dataset: + return [] + + for hit_callback in hit_callbacks: + hit_callback.on_query(query, dataset.id) + + # get retrieval model , if the model is not setting , using default + retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model + + if dataset.indexing_technique == "economy": + # use keyword table query + documents = RetrievalService.retrieve(retrival_method='keyword_search', + dataset_id=dataset.id, + query=query, + top_k=self.top_k + ) + if documents: + all_documents.extend(documents) + else: + if self.top_k > 0: + # retrieval source + documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + dataset_id=dataset.id, + query=query, + top_k=self.top_k, + score_threshold=retrieval_model['score_threshold'] + if retrieval_model['score_threshold_enabled'] else None, + reranking_model=retrieval_model['reranking_model'] + if retrieval_model['reranking_enable'] else None + ) + + all_documents.extend(documents) \ No newline at end of file diff --git a/api/core/workflow/nodes/knowledge_retrieval/dataset_retriever_tool.py b/api/core/workflow/nodes/knowledge_retrieval/dataset_retriever_tool.py new file mode 100644 index 0000000000..13331d981b --- /dev/null +++ b/api/core/workflow/nodes/knowledge_retrieval/dataset_retriever_tool.py @@ -0,0 +1,159 @@ +from typing import Optional + +from langchain.tools import BaseTool +from pydantic import BaseModel, Field + +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.rag.datasource.retrieval_service import RetrievalService +from extensions.ext_database import db +from models.dataset import Dataset, Document, DocumentSegment + +default_retrieval_model = { + 'search_method': 'semantic_search', + 'reranking_enable': False, + 'reranking_model': { + 
'reranking_provider_name': '', + 'reranking_model_name': '' + }, + 'top_k': 2, + 'score_threshold_enabled': False +} + + +class DatasetRetrieverToolInput(BaseModel): + query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.") + + +class DatasetRetrieverTool(BaseTool): + """Tool for querying a Dataset.""" + name: str = "dataset" + args_schema: type[BaseModel] = DatasetRetrieverToolInput + description: str = "use this to retrieve a dataset. " + + tenant_id: str + dataset_id: str + top_k: int = 2 + score_threshold: Optional[float] = None + hit_callbacks: list[DatasetIndexToolCallbackHandler] = [] + return_resource: bool + retriever_from: str + + @classmethod + def from_dataset(cls, dataset: Dataset, **kwargs): + description = dataset.description + if not description: + description = 'useful for when you want to answer queries about the ' + dataset.name + + description = description.replace('\n', '').replace('\r', '') + return cls( + name=f'dataset-{dataset.id}', + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + description=description, + **kwargs + ) + + def _run(self, query: str) -> str: + dataset = db.session.query(Dataset).filter( + Dataset.tenant_id == self.tenant_id, + Dataset.id == self.dataset_id + ).first() + + if not dataset: + return '' + + for hit_callback in self.hit_callbacks: + hit_callback.on_query(query, dataset.id) + + # get retrieval model , if the model is not setting , using default + retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model + if dataset.indexing_technique == "economy": + # use keyword table query + documents = RetrievalService.retrieve(retrival_method='keyword_search', + dataset_id=dataset.id, + query=query, + top_k=self.top_k + ) + return str("\n".join([document.page_content for document in documents])) + else: + if self.top_k > 0: + # retrieval source + documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + dataset_id=dataset.id, + query=query, + top_k=self.top_k, + score_threshold=retrieval_model['score_threshold'] + if retrieval_model['score_threshold_enabled'] else None, + reranking_model=retrieval_model['reranking_model'] + if retrieval_model['reranking_enable'] else None + ) + else: + documents = [] + + for hit_callback in self.hit_callbacks: + hit_callback.on_tool_end(documents) + document_score_list = {} + if dataset.indexing_technique != "economy": + for item in documents: + if 'score' in item.metadata and item.metadata['score']: + document_score_list[item.metadata['doc_id']] = item.metadata['score'] + document_context_list = [] + index_node_ids = [document.metadata['doc_id'] for document in documents] + segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id, + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == 'completed', + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids) + ).all() + + if segments: + index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} + sorted_segments = sorted(segments, + key=lambda segment: index_node_id_to_position.get(segment.index_node_id, + float('inf'))) + for segment in sorted_segments: + if segment.answer: + document_context_list.append(f'question:{segment.content} answer:{segment.answer}') + else: + document_context_list.append(segment.content) + if self.return_resource: + context_list = [] + resource_number = 1 + for segment in sorted_segments: + context = {} + document = 
+                                               Document.query.filter(Document.id == segment.document_id,
+                                                                     Document.enabled == True,
+                                                                     Document.archived == False,
+                                                                     ).first()
+                    if dataset and document:
+                        source = {
+                            'position': resource_number,
+                            'dataset_id': dataset.id,
+                            'dataset_name': dataset.name,
+                            'document_id': document.id,
+                            'document_name': document.name,
+                            'data_source_type': document.data_source_type,
+                            'segment_id': segment.id,
+                            'retriever_from': self.retriever_from,
+                            'score': document_score_list.get(segment.index_node_id, None)
+                        }
+                        if self.retriever_from == 'dev':
+                            source['hit_count'] = segment.hit_count
+                            source['word_count'] = segment.word_count
+                            source['segment_position'] = segment.position
+                            source['index_node_hash'] = segment.index_node_hash
+                        if segment.answer:
+                            source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
+                        else:
+                            source['content'] = segment.content
+                        context_list.append(source)
+                        resource_number += 1
+
+                for hit_callback in self.hit_callbacks:
+                    hit_callback.return_retriever_resource_info(context_list)
+
+            return str("\n".join(document_context_list))
+
+    async def _arun(self, tool_input: str) -> str:
+        raise NotImplementedError()
diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py
new file mode 100644
index 0000000000..905ee1f80d
--- /dev/null
+++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py
@@ -0,0 +1,51 @@
+from typing import Any, Literal, Optional
+
+from pydantic import BaseModel
+
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.variable_entities import VariableSelector
+
+
+class RerankingModelConfig(BaseModel):
+    """
+    Reranking Model Config.
+    """
+    provider: str
+    name: str
+
+
+class MultipleRetrievalConfig(BaseModel):
+    """
+    Multiple Retrieval Config.
+    """
+    top_k: int
+    score_threshold: Optional[float]
+    reranking_model: RerankingModelConfig
+
+
+class ModelConfig(BaseModel):
+    """
+    Model Config.
+    """
+    provider: str
+    name: str
+    mode: str
+    completion_params: dict[str, Any] = {}
+
+
+class SingleRetrievalConfig(BaseModel):
+    """
+    Single Retrieval Config.
+    """
+    model: ModelConfig
+
+
+class KnowledgeRetrievalNodeData(BaseNodeData):
+    """
+    Knowledge retrieval Node Data.
+ """ + variables: list[VariableSelector] + dataset_ids: list[str] + retrieval_mode: Literal['single', 'multiple'] + multiple_retrieval_config: MultipleRetrievalConfig + singleRetrievalConfig: SingleRetrievalConfig diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 7b8344418b..1ccdbf971c 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -1,13 +1,371 @@ +import threading +from typing import cast, Any + +from flask import current_app, Flask + +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.entities.model_entities import ModelStatus +from core.errors.error import ProviderTokenNotInitError, ModelCurrentlyNotSupportError, QuotaExceededError +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.rag.datasource.retrieval_service import RetrievalService +from core.rerank.rerank import RerankRunner from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData +from extensions.ext_database import db +from models.dataset import Dataset, DocumentSegment, Document +from models.workflow import WorkflowNodeExecutionStatus +default_retrieval_model = { + 'search_method': 'semantic_search', + 'reranking_enable': False, + 'reranking_model': { + 'reranking_provider_name': '', + 'reranking_model_name': '' + }, + 'top_k': 2, + 'score_threshold_enabled': False +} class KnowledgeRetrievalNode(BaseNode): + + _node_data_cls = KnowledgeRetrievalNodeData + _node_type = NodeType.TOOL + def _run(self, variable_pool: VariablePool) -> NodeRunResult: - pass + node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data) + + # extract variables + variables = { + variable_selector.variable: variable_pool.get_variable_value( + variable_selector=variable_selector.value_selector) + for variable_selector in node_data.variables + } + + # retrieve knowledge + try: + outputs = self._fetch_dataset_retriever( + node_data=node_data, variables=variables + ) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=variables, + process_data=None, + outputs=outputs + ) + + except Exception as e: + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error=str(e) + ) + def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]) -> list[dict[str, Any]]: + """ + A dataset tool is a tool that can be used to retrieve information from a dataset + :param 
node_data: node data
+        :param variables: variables
+        """
+        # collect only the datasets that are available to this tenant
+        available_datasets = []
+        dataset_ids = node_data.dataset_ids
+        for dataset_id in dataset_ids:
+            # get dataset from dataset id
+            dataset = db.session.query(Dataset).filter(
+                Dataset.tenant_id == self.tenant_id,
+                Dataset.id == dataset_id
+            ).first()
+
+            # pass if dataset is not available
+            if not dataset:
+                continue
+
+            # pass if dataset has no available documents or segments
+            if (dataset and dataset.available_document_count == 0
+                    and dataset.available_segment_count == 0):
+                continue
+
+            available_datasets.append(dataset)
+        all_documents = []
+        if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
+            all_documents = self._single_retrieve(available_datasets, node_data, variables)
+        elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
+            all_documents = self._multiple_retrieve(available_datasets, node_data, variables)
+
+        document_score_list = {}
+        for item in all_documents:
+            if 'score' in item.metadata and item.metadata['score']:
+                document_score_list[item.metadata['doc_id']] = item.metadata['score']
+
+        document_context_list = []
+        index_node_ids = [document.metadata['doc_id'] for document in all_documents]
+        segments = DocumentSegment.query.filter(
+            DocumentSegment.dataset_id.in_(dataset_ids),
+            DocumentSegment.completed_at.isnot(None),
+            DocumentSegment.status == 'completed',
+            DocumentSegment.enabled == True,
+            DocumentSegment.index_node_id.in_(index_node_ids)
+        ).all()
+        context_list = []
+        if segments:
+            index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
+            sorted_segments = sorted(segments,
+                                     key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
+                                                                                       float('inf')))
+            for segment in sorted_segments:
+                if segment.answer:
+                    document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
+                else:
+                    document_context_list.append(segment.content)
+
+            for segment in sorted_segments:
+                dataset = Dataset.query.filter_by(
+                    id=segment.dataset_id
+                ).first()
+                document = Document.query.filter(Document.id == segment.document_id,
+                                                 Document.enabled == True,
+                                                 Document.archived == False,
+                                                 ).first()
+                if dataset and document:
+
+                    source = {
+                        'metadata': {
+                            '_source': 'knowledge',
+                            'dataset_id': dataset.id,
+                            'dataset_name': dataset.name,
+                            'document_id': document.id,
+                            'document_name': document.name,
+                            'document_data_source_type': document.data_source_type,
+                            'segment_id': segment.id,
+                            'retriever_from': 'workflow',
+                            'score': document_score_list.get(segment.index_node_id, None),
+                            'segment_hit_count': segment.hit_count,
+                            'segment_word_count': segment.word_count,
+                            'segment_position': segment.position
+                        }
+                    }
+                    if segment.answer:
+                        source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
+                    else:
+                        source['content'] = segment.content
+                    context_list.append(source)
+
+        return context_list
 
     @classmethod
     def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
-        pass
+        # narrow node_data to the concrete KnowledgeRetrievalNodeData type
+        node_data = cast(cls._node_data_cls, node_data)
+        return {
+            variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
+        }
+
+    def _single_retrieve(self, available_datasets, node_data, variables):
+        tools = []
+        for dataset in available_datasets:
+            description = dataset.description
+            if not description:
+                description = 'useful for when you want to answer queries about the ' + dataset.name
+
+            description = 
description.replace('\n', '').replace('\r', '') + message_tool = PromptMessageTool( + name=dataset.id, + description=description, + parameters={ + "type": "object", + "properties": {}, + "required": [], + } + ) + tools.append(message_tool) + # fetch model config + model_instance, model_config = self._fetch_model_config(node_data) + prompt_messages = [ + SystemPromptMessage(content='You are a helpful AI assistant.'), + UserPromptMessage(content=variables['#query#']) + ] + result = model_instance.invoke_llm( + prompt_messages=prompt_messages, + tools=tools, + stream=False, + model_parameters={ + 'temperature': 0.2, + 'top_p': 0.3, + 'max_tokens': 1500 + } + ) + + if result.message.tool_calls: + # get retrieval model config + function_call_name = result.message.tool_calls[0].function.name + dataset = db.session.query(Dataset).filter( + Dataset.id == function_call_name + ).first() + if dataset: + retrieval_model_config = dataset.retrieval_model \ + if dataset.retrieval_model else default_retrieval_model + + # get top k + top_k = retrieval_model_config['top_k'] + # get retrieval method + retrival_method = retrieval_model_config['search_method'] + # get reranking model + reranking_model = retrieval_model_config['reranking_model'] + # get score threshold + score_threshold = .0 + score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") + if score_threshold_enabled: + score_threshold = retrieval_model_config.get("score_threshold") + + results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=variables['#query#'], + top_k=top_k, score_threshold=score_threshold, + reranking_model=reranking_model) + return results + + + + def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: + """ + Fetch model config + :param node_data: node data + :return: + """ + model_name = node_data.singleRetrievalConfig.model.name + provider_name = node_data.singleRetrievalConfig.model.provider + + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=self.tenant_id, + model_type=ModelType.LLM, + provider=provider_name, + model=model_name + ) + + provider_model_bundle = model_instance.provider_model_bundle + model_type_instance = model_instance.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + + model_credentials = model_instance.credentials + + # check model + provider_model = provider_model_bundle.configuration.get_provider_model( + model=model_name, + model_type=ModelType.LLM + ) + + if provider_model is None: + raise ValueError(f"Model {model_name} not exist.") + + if provider_model.status == ModelStatus.NO_CONFIGURE: + raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + elif provider_model.status == ModelStatus.NO_PERMISSION: + raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") + elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: + raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") + + # model config + completion_params = node_data.singleRetrievalConfig.model.completion_params + stop = [] + if 'stop' in completion_params: + stop = completion_params['stop'] + del completion_params['stop'] + + # get model mode + model_mode = node_data.singleRetrievalConfig.model.mode + if not model_mode: + raise ValueError("LLM mode is required.") + + model_schema = model_type_instance.get_model_schema( + 
model_name,
+            model_credentials
+        )
+
+        if not model_schema:
+            raise ValueError(f"Model {model_name} not exist.")
+
+        return model_instance, ModelConfigWithCredentialsEntity(
+            provider=provider_name,
+            model=model_name,
+            model_schema=model_schema,
+            mode=model_mode,
+            provider_model_bundle=provider_model_bundle,
+            credentials=model_credentials,
+            parameters=completion_params,
+            stop=stop,
+        )
+
+    def _multiple_retrieve(self, available_datasets, node_data, variables):
+        threads = []
+        all_documents = []
+        # fan out one retrieval thread per available dataset
+        for dataset in available_datasets:
+            retrieval_thread = threading.Thread(target=self._retriever, kwargs={
+                'flask_app': current_app._get_current_object(),
+                'dataset_id': dataset.id,
+                'query': variables['#query#'],
+                'top_k': node_data.multiple_retrieval_config.top_k,
+                'all_documents': all_documents,
+            })
+            threads.append(retrieval_thread)
+            retrieval_thread.start()
+        for thread in threads:
+            thread.join()
+        # rerank the documents retrieved from all datasets
+        model_manager = ModelManager()
+        rerank_model_instance = model_manager.get_model_instance(
+            tenant_id=self.tenant_id,
+            provider=node_data.multiple_retrieval_config.reranking_model.provider,
+            model_type=ModelType.RERANK,
+            model=node_data.multiple_retrieval_config.reranking_model.name
+        )
+
+        rerank_runner = RerankRunner(rerank_model_instance)
+        all_documents = rerank_runner.run(variables['#query#'], all_documents,
+                                          node_data.multiple_retrieval_config.score_threshold,
+                                          node_data.multiple_retrieval_config.top_k)
+
+        return all_documents
+
+     def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
+        with flask_app.app_context():
+            dataset = db.session.query(Dataset).filter(
+                Dataset.tenant_id == self.tenant_id,
+                Dataset.id == dataset_id
+            ).first()
+
+            if not dataset:
+                return []
+
+            # get the retrieval model; if it is not set, use the default
+            retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
+
+            if dataset.indexing_technique == "economy":
+                # use keyword table query
+                documents = RetrievalService.retrieve(retrival_method='keyword_search',
+                                                      dataset_id=dataset.id,
+                                                      query=query,
+                                                      top_k=top_k
+                                                      )
+                if documents:
+                    all_documents.extend(documents)
+            else:
+                if top_k > 0:
+                    # retrieval source
+                    documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+                                                          dataset_id=dataset.id,
+                                                          query=query,
+                                                          top_k=top_k,
+                                                          score_threshold=retrieval_model['score_threshold']
+                                                          if retrieval_model['score_threshold_enabled'] else None,
+                                                          reranking_model=retrieval_model['reranking_model']
+                                                          if retrieval_model['reranking_enable'] else None
+                                                          )
+
+                    all_documents.extend(documents)
\ No newline at end of file

From 9b57b4c6c8591473a4b6f88b3e6d58f0ab5ace53 Mon Sep 17 00:00:00 2001
From: jyong
Date: Fri, 15 Mar 2024 16:14:32 +0800
Subject: [PATCH 200/200] dataset retrieval

---
 .../knowledge_retrieval_node.py | 24 ++++++++++---------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
index 1ccdbf971c..a501113dc3 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -33,10 +33,10 @@ default_retrieval_model = {
     'score_threshold_enabled': False
 }
 
-class KnowledgeRetrievalNode(BaseNode):
+class 
KnowledgeRetrievalNode(BaseNode): _node_data_cls = KnowledgeRetrievalNodeData - _node_type = NodeType.TOOL + _node_type = NodeType.KNOWLEDGE_RETRIEVAL def _run(self, variable_pool: VariablePool) -> NodeRunResult: node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data) @@ -67,7 +67,9 @@ class KnowledgeRetrievalNode(BaseNode): inputs=variables, error=str(e) ) - def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]) -> list[dict[str, Any]]: + + def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]) -> list[ + dict[str, Any]]: """ A dataset tool is a tool that can be used to retrieve information from a dataset :param node_data: node data @@ -224,14 +226,14 @@ class KnowledgeRetrievalNode(BaseNode): if score_threshold_enabled: score_threshold = retrieval_model_config.get("score_threshold") - results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=variables['#query#'], - top_k=top_k, score_threshold=score_threshold, - reranking_model=reranking_model) + results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, + query=variables['#query#'], + top_k=top_k, score_threshold=score_threshold, + reranking_model=reranking_model) return results - - - def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: + def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ + ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config :param node_data: node data @@ -333,7 +335,7 @@ class KnowledgeRetrievalNode(BaseNode): return all_documents - def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): + def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): with flask_app.app_context(): dataset = db.session.query(Dataset).filter( Dataset.tenant_id == self.tenant_id, @@ -368,4 +370,4 @@ class KnowledgeRetrievalNode(BaseNode): if retrieval_model['reranking_enable'] else None ) - all_documents.extend(documents) \ No newline at end of file + all_documents.extend(documents)
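A note on the multi-dataset retrieval introduced above: `_multiple_retrieve` fans out one worker thread per dataset, lets every worker append its hits to a shared list, joins all workers, and only then reranks the merged pool so scores are comparable across datasets. The sketch below shows that fan-out/join/rerank shape in isolation. It is a minimal, self-contained approximation rather than Dify code: `fake_retrieve` and `rerank_by_score` are hypothetical stand-ins for `RetrievalService.retrieve` and `RerankRunner.run`, and the real node additionally pushes a Flask app context into each worker.

```python
import threading

def fake_retrieve(dataset_id: str, query: str, top_k: int) -> list[dict]:
    # Hypothetical stand-in for RetrievalService.retrieve(): return scored chunks.
    return [{"dataset_id": dataset_id, "content": f"chunk about {query!r}", "score": 0.5}][:top_k]

def rerank_by_score(query: str, documents: list[dict], top_k: int) -> list[dict]:
    # Hypothetical stand-in for RerankRunner.run(): keep the best top_k overall.
    return sorted(documents, key=lambda d: d["score"], reverse=True)[:top_k]

def multiple_retrieve(dataset_ids: list[str], query: str, top_k: int) -> list[dict]:
    all_documents: list[dict] = []  # shared accumulator, one extend() per worker
    threads = []
    for dataset_id in dataset_ids:
        # one thread per dataset, mirroring the node's threading.Thread fan-out
        thread = threading.Thread(
            target=lambda d=dataset_id: all_documents.extend(fake_retrieve(d, query, top_k))
        )
        threads.append(thread)
        thread.start()
    for thread in threads:
        thread.join()  # wait for every dataset before reranking the merged pool
    return rerank_by_score(query, all_documents, top_k)

if __name__ == '__main__':
    print(multiple_retrieve(['dataset-a', 'dataset-b'], 'what is a knowledge retrieval node?', top_k=2))
```

`list.extend` is atomic under CPython's GIL, which is why a plain list works as the shared accumulator here, just as it does in the node itself.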