mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 04:26:30 +08:00
fix: gpustack llm and text_embedding model url path wrong after edited
This commit is contained in:
parent
409cc7d9b0
commit
5ba875f96b
@ -1,7 +1,5 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
|
||||||
from yarl import URL
|
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult
|
from core.model_runtime.entities.llm_entities import LLMResult
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities.message_entities import (
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
@ -24,9 +22,10 @@ class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel):
|
|||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: str | None = None,
|
user: str | None = None,
|
||||||
) -> LLMResult | Generator:
|
) -> LLMResult | Generator:
|
||||||
|
compatible_credentials = self._get_compatible_credentials(credentials)
|
||||||
return super()._invoke(
|
return super()._invoke(
|
||||||
model,
|
model,
|
||||||
credentials,
|
compatible_credentials,
|
||||||
prompt_messages,
|
prompt_messages,
|
||||||
model_parameters,
|
model_parameters,
|
||||||
tools,
|
tools,
|
||||||
@ -36,10 +35,15 @@ class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
self._add_custom_parameters(credentials)
|
compatible_credentials = self._get_compatible_credentials(credentials)
|
||||||
super().validate_credentials(model, credentials)
|
super().validate_credentials(model, compatible_credentials)
|
||||||
|
|
||||||
|
def _get_compatible_credentials(self, credentials: dict) -> dict:
|
||||||
|
credentials = credentials.copy()
|
||||||
|
base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai")
|
||||||
|
credentials["endpoint_url"] = f"{base_url}/v1-openai"
|
||||||
|
return credentials
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _add_custom_parameters(credentials: dict) -> None:
|
def _add_custom_parameters(credentials: dict) -> None:
|
||||||
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
|
|
||||||
credentials["mode"] = "chat"
|
credentials["mode"] = "chat"
|
||||||
|
|||||||
@ -1,7 +1,5 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from yarl import URL
|
|
||||||
|
|
||||||
from core.entities.embedding_type import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.text_embedding_entities import (
|
from core.model_runtime.entities.text_embedding_entities import (
|
||||||
TextEmbeddingResult,
|
TextEmbeddingResult,
|
||||||
@ -24,12 +22,15 @@ class GPUStackTextEmbeddingModel(OAICompatEmbeddingModel):
|
|||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||||
) -> TextEmbeddingResult:
|
) -> TextEmbeddingResult:
|
||||||
return super()._invoke(model, credentials, texts, user, input_type)
|
compatible_credentials = self._get_compatible_credentials(credentials)
|
||||||
|
return super()._invoke(model, compatible_credentials, texts, user, input_type)
|
||||||
|
|
||||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
self._add_custom_parameters(credentials)
|
compatible_credentials = self._get_compatible_credentials(credentials)
|
||||||
super().validate_credentials(model, credentials)
|
super().validate_credentials(model, compatible_credentials)
|
||||||
|
|
||||||
@staticmethod
|
def _get_compatible_credentials(self, credentials: dict) -> dict:
|
||||||
def _add_custom_parameters(credentials: dict) -> None:
|
credentials = credentials.copy()
|
||||||
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
|
base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai")
|
||||||
|
credentials["endpoint_url"] = f"{base_url}/v1-openai"
|
||||||
|
return credentials
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user