From 5a679ed396ce075c233ea090fb1df023fcb132c8 Mon Sep 17 00:00:00 2001 From: JzoNg Date: Thu, 7 Nov 2024 21:50:22 +0800 Subject: [PATCH] provider compatible in model_config --- .../components/app/configuration/index.tsx | 36 +++++++++++++++---- web/app/components/workflow/utils.ts | 8 +---- web/utils/index.ts | 7 ++++ 3 files changed, 38 insertions(+), 13 deletions(-) diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx index bf6c5e79c8..2480524ce5 100644 --- a/web/app/components/app/configuration/index.tsx +++ b/web/app/components/app/configuration/index.tsx @@ -71,8 +71,9 @@ import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants' import { SupportUploadFileTypes } from '@/app/components/workflow/types' import NewFeaturePanel from '@/app/components/base/features/new-feature-panel' import { fetchFileUploadConfig } from '@/service/common' +import { correctProvider } from '@/utils' -interface PublishConfig { +type PublishConfig = { modelConfig: ModelConfig completionParams: FormValue } @@ -156,6 +157,7 @@ const Configuration: FC = () => { const setCompletionParams = (value: FormValue) => { const params = { ...value } + // eslint-disable-next-line ts/no-use-before-define if ((!params.stop || params.stop.length === 0) && (modeModeTypeRef.current === ModelModeType.completion)) { params.stop = getTempStop() setTempStop([]) @@ -164,7 +166,7 @@ const Configuration: FC = () => { } const [modelConfig, doSetModelConfig] = useState({ - provider: 'openai', + provider: 'langgenius/openai/openai', model_id: 'gpt-3.5-turbo', mode: ModelModeType.unset, configs: { @@ -187,7 +189,7 @@ const Configuration: FC = () => { const isAgent = mode === 'agent-chat' - const isOpenAI = modelConfig.provider === 'openai' + const isOpenAI = modelConfig.provider === 'langgenius/openai/openai' const [collectionList, setCollectionList] = useState([]) useEffect(() => { @@ -356,6 +358,7 @@ const Configuration: FC = () => { const [canReturnToSimpleMode, setCanReturnToSimpleMode] = useState(true) const setPromptMode = async (mode: PromptMode) => { if (mode === PromptMode.advanced) { + // eslint-disable-next-line ts/no-use-before-define await migrateToDefaultPrompt() setCanReturnToSimpleMode(true) } @@ -540,8 +543,19 @@ const Configuration: FC = () => { if (modelConfig.retriever_resource) setCitationConfig(modelConfig.retriever_resource) - if (modelConfig.annotation_reply) - setAnnotationConfig(modelConfig.annotation_reply, true) + if (modelConfig.annotation_reply) { + let annotationConfig = modelConfig.annotation_reply + if (modelConfig.annotation_reply.enabled) { + annotationConfig = { + ...modelConfig.annotation_reply, + embedding_model: { + ...modelConfig.annotation_reply.embedding_model, + embedding_provider_name: correctProvider(modelConfig.annotation_reply.embedding_model.embedding_provider_name), + }, + } + } + setAnnotationConfig(annotationConfig, true) + } if (modelConfig.sensitive_word_avoidance) setModerationConfig(modelConfig.sensitive_word_avoidance) @@ -551,7 +565,7 @@ const Configuration: FC = () => { const config = { modelConfig: { - provider: model.provider, + provider: correctProvider(model.provider), model_id: model.name, mode: model.mode, configs: { @@ -605,6 +619,10 @@ const Configuration: FC = () => { ...tool, isDeleted: res.deleted_tools?.includes(tool.tool_name), notAuthor: collectionList.find(c => tool.provider_id === c.id)?.is_team_authorization === false, + ...(tool.provider_type === 'builtin' ? { + provider_id: correctProvider(tool.provider_name), + provider_name: correctProvider(tool.provider_name), + } : {}), } }), } : DEFAULT_AGENT_SETTING, @@ -622,6 +640,12 @@ const Configuration: FC = () => { retrieval_model: RETRIEVE_TYPE.multiWay, ...modelConfig.dataset_configs, ...retrievalConfig, + ...(retrievalConfig.reranking_model ? { + reranking_model: { + ...retrievalConfig.reranking_model, + reranking_provider_name: correctProvider(modelConfig.dataset_configs.reranking_model.reranking_provider_name), + }, + } : {}), }) setHasFetchedDetail(true) }) diff --git a/web/app/components/workflow/utils.ts b/web/app/components/workflow/utils.ts index 0264b00f3d..f7d15dae21 100644 --- a/web/app/components/workflow/utils.ts +++ b/web/app/components/workflow/utils.ts @@ -35,6 +35,7 @@ import type { ToolNodeType } from './nodes/tool/types' import type { IterationNodeType } from './nodes/iteration/types' import { CollectionType } from '@/app/components/tools/types' import { toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema' +import { correctProvider } from '@/utils' const WHITE = 'WHITE' const GRAY = 'GRAY' @@ -212,13 +213,6 @@ export const preprocessNodesAndEdges = (nodes: Node[], edges: Edge[]) => { } } -export const correctProvider = (provider: string) => { - if (provider.includes('/')) - return provider - - return `langgenius/${provider}/${provider}` -} - export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => { const { nodes, edges } = preprocessNodesAndEdges(cloneDeep(originNodes), cloneDeep(originEdges)) const firstNode = nodes[0] diff --git a/web/utils/index.ts b/web/utils/index.ts index 7aa6fef0a8..d165596fe3 100644 --- a/web/utils/index.ts +++ b/web/utils/index.ts @@ -57,3 +57,10 @@ export async function fetchWithRetry(fn: Promise, retries = 3): Prom return [null, res] } } + +export const correctProvider = (provider: string) => { + if (provider.includes('/')) + return provider + + return `langgenius/${provider}/${provider}` +}