mirror of https://github.com/langgenius/dify.git
Merge remote-tracking branch 'origin/feat/workflow-backend' into feat/workflow-backend
This commit is contained in:
commit
41d9fdee50
|
|
@ -17,10 +17,9 @@ class BedrockProvider(ModelProvider):
|
|||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
# Validate with the model named in `model_for_validation`, defaulting to `amazon.titan-text-lite-v1`.
|
||||
bedrock_validate_model_name = credentials.get('model_for_validation', 'amazon.titan-text-lite-v1')
|
||||
model_instance.validate_credentials(
|
||||
model='amazon.titan-text-lite-v1',
|
||||
model=bedrock_validate_model_name,
|
||||
credentials=credentials
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
|
|
|
|||
|
|
@ -69,3 +69,12 @@ provider_credential_schema:
|
|||
label:
|
||||
en_US: AWS GovCloud (US-West)
|
||||
zh_Hans: AWS GovCloud (US-West)
|
||||
- variable: model_for_validation
|
||||
required: false
|
||||
label:
|
||||
en_US: Available Model Name
|
||||
zh_Hans: 可用模型名称
|
||||
type: text-input
|
||||
placeholder:
|
||||
en_US: A model you have access to (e.g. amazon.titan-text-lite-v1) for validation.
|
||||
zh_Hans: 为了进行验证,请输入一个您可用的模型名称 (例如:amazon.titan-text-lite-v1)
|
||||
|
|
|
|||
|
|
@ -656,6 +656,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
|||
if assistant_message_function_call:
|
||||
# start of stream function call
|
||||
delta_assistant_message_function_call_storage = assistant_message_function_call
|
||||
if delta_assistant_message_function_call_storage.arguments is None:
|
||||
delta_assistant_message_function_call_storage.arguments = ''
|
||||
if not has_finish_reason:
|
||||
continue
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,4 @@
|
|||
model: text-embedding-v1
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 2048
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
model: text-embedding-v2
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 2048
|
||||
|
|
@ -0,0 +1,132 @@
|
|||
import time
|
||||
from typing import Optional
|
||||
|
||||
import dashscope
|
||||
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import (
|
||||
EmbeddingUsage,
|
||||
TextEmbeddingResult,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import (
|
||||
TextEmbeddingModel,
|
||||
)
|
||||
from core.model_runtime.model_providers.tongyi._common import _CommonTongyi
|
||||
|
||||
|
||||
class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
    """
    Model class for Tongyi text embedding model.

    Embeddings are requested through ``dashscope.TextEmbedding.call``, one
    request per input text.
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        texts: list[str],
        user: Optional[str] = None,
    ) -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id (not used by this implementation)
        :return: embeddings result
        """
        credentials_kwargs = self._to_credential_kwargs(credentials)
        # NOTE(review): the key is assigned on the `dashscope` module itself, not
        # passed per call; concurrent invocations with different credentials
        # would presumably share it -- confirm whether that matters here.
        dashscope.api_key = credentials_kwargs["dashscope_api_key"]
        embeddings, embedding_used_tokens = self.embed_documents(model, texts)

        return TextEmbeddingResult(
            embeddings=embeddings,
            usage=self._calc_response_usage(model, credentials_kwargs, embedding_used_tokens),
            model=model
        )

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for the given texts

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: sum of the per-text token counts (0 for an empty list)
        """
        # sum() over an empty iterable is 0, so no separate empty check is needed.
        return sum(self._get_num_tokens_by_gpt2(text) for text in texts)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        :raises CredentialsValidateFailedError: if the probe embedding call fails
        """
        try:
            # transform credentials to kwargs for model instance
            credentials_kwargs = self._to_credential_kwargs(credentials)
            dashscope.api_key = credentials_kwargs["dashscope_api_key"]
            # call embedding model with a minimal probe text
            self.embed_documents(model=model, texts=["ping"])
        except Exception as ex:
            # Chain the original exception so the root cause stays in the traceback.
            raise CredentialsValidateFailedError(str(ex)) from ex

    @staticmethod
    def embed_documents(model: str, texts: list[str]) -> tuple[list[list[float]], int]:
        """Call out to Tongyi's embedding endpoint.

        Args:
            model: The embedding model name.
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text, and tokens usage.

        Raises:
            ValueError: If a response carries no embeddings or no token usage.
        """
        embeddings = []
        embedding_used_tokens = 0
        for text in texts:
            response = dashscope.TextEmbedding.call(model=model, input=text, text_type="document")

            # Check the payload before indexing into it; otherwise a response
            # without output surfaces as an opaque
            # "'NoneType' object is not subscriptable" TypeError.
            output = response.output
            if not output or "embeddings" not in output or not output["embeddings"]:
                raise ValueError(
                    f"Tongyi embedding response is missing output or embeddings: {response}"
                )

            data = output["embeddings"][0]
            if "embedding" not in data:
                raise ValueError("Tongyi embedding response is missing the embedding data.")

            usage = response.usage
            if not usage or "total_tokens" not in usage:
                raise ValueError("Tongyi embedding response is missing total token usage.")

            embeddings.append(data["embedding"])
            embedding_used_tokens += usage["total_tokens"]

        return [list(map(float, e)) for e in embeddings], embedding_used_tokens

    def _calc_response_usage(
        self, model: str, credentials: dict, tokens: int
    ) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: credentials forwarded to ``get_price``
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage; latency is measured from self.started_at
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at
        )

        return usage
|
||||
|
|
@ -17,6 +17,7 @@ help:
|
|||
supported_model_types:
|
||||
- llm
|
||||
- tts
|
||||
- text-embedding
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
|
|
|
|||
|
|
@ -32,3 +32,8 @@ parameter_rules:
|
|||
zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。
|
||||
en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return.
|
||||
required: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 8192
|
||||
|
|
|
|||
|
|
@ -30,3 +30,8 @@ parameter_rules:
|
|||
zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。
|
||||
en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return.
|
||||
required: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 8192
|
||||
|
|
|
|||
|
|
@ -0,0 +1,36 @@
|
|||
"""add_tenant_id_db_index
|
||||
|
||||
Revision ID: a8f9b3c45e4a
|
||||
Revises: 16830a790f0f
|
||||
Create Date: 2024-03-18 05:07:35.588473
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'a8f9b3c45e4a'
|
||||
down_revision = '16830a790f0f'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
    """Create a non-unique ``tenant_id`` index on ``document_segments``, then on ``documents``."""
    # ### commands auto generated by Alembic - please adjust! ###
    # (table name, index name) pairs, applied in this order.
    tenant_indexes = (
        ('document_segments', 'document_segment_tenant_idx'),
        ('documents', 'document_tenant_idx'),
    )
    for table_name, index_name in tenant_indexes:
        with op.batch_alter_table(table_name, schema=None) as batch_op:
            batch_op.create_index(index_name, ['tenant_id'], unique=False)

    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
    """Drop the ``tenant_id`` indexes from ``documents``, then from ``document_segments``."""
    # ### commands auto generated by Alembic - please adjust! ###
    # (table name, index name) pairs, applied in this order.
    tenant_indexes = (
        ('documents', 'document_tenant_idx'),
        ('document_segments', 'document_segment_tenant_idx'),
    )
    for table_name, index_name in tenant_indexes:
        with op.batch_alter_table(table_name, schema=None) as batch_op:
            batch_op.drop_index(index_name)

    # ### end Alembic commands ###
|
||||
|
|
@ -11,7 +11,7 @@ from sqlalchemy.dialects import postgresql
|
|||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'b289e2408ee2'
|
||||
down_revision = '16830a790f0f'
|
||||
down_revision = 'a8f9b3c45e4a'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
|
|
|||
|
|
@ -176,6 +176,7 @@ class Document(db.Model):
|
|||
db.PrimaryKeyConstraint('id', name='document_pkey'),
|
||||
db.Index('document_dataset_id_idx', 'dataset_id'),
|
||||
db.Index('document_is_paused_idx', 'is_paused'),
|
||||
db.Index('document_tenant_idx', 'tenant_id'),
|
||||
)
|
||||
|
||||
# initial fields
|
||||
|
|
@ -334,6 +335,7 @@ class DocumentSegment(db.Model):
|
|||
db.Index('document_segment_tenant_dataset_idx', 'dataset_id', 'tenant_id'),
|
||||
db.Index('document_segment_tenant_document_idx', 'document_id', 'tenant_id'),
|
||||
db.Index('document_segment_dataset_node_idx', 'dataset_id', 'index_node_id'),
|
||||
db.Index('document_segment_tenant_idx', 'tenant_id'),
|
||||
)
|
||||
|
||||
# initial fields
|
||||
|
|
|
|||
Loading…
Reference in New Issue