diff --git a/api/core/model_runtime/model_providers/bedrock/bedrock.py b/api/core/model_runtime/model_providers/bedrock/bedrock.py index aa322fc664..96cb90280e 100644 --- a/api/core/model_runtime/model_providers/bedrock/bedrock.py +++ b/api/core/model_runtime/model_providers/bedrock/bedrock.py @@ -17,10 +17,9 @@ class BedrockProvider(ModelProvider): """ try: model_instance = self.get_model_instance(ModelType.LLM) - - # Use `gemini-pro` model for validate, + bedrock_validate_model_name = credentials.get('model_for_validation', 'amazon.titan-text-lite-v1') model_instance.validate_credentials( - model='amazon.titan-text-lite-v1', + model=bedrock_validate_model_name, credentials=credentials ) except CredentialsValidateFailedError as ex: diff --git a/api/core/model_runtime/model_providers/bedrock/bedrock.yaml b/api/core/model_runtime/model_providers/bedrock/bedrock.yaml index 1458b830cd..05cd402d4e 100644 --- a/api/core/model_runtime/model_providers/bedrock/bedrock.yaml +++ b/api/core/model_runtime/model_providers/bedrock/bedrock.yaml @@ -69,3 +69,12 @@ provider_credential_schema: label: en_US: AWS GovCloud (US-West) zh_Hans: AWS GovCloud (US-West) + - variable: model_for_validation + required: false + label: + en_US: Available Model Name + zh_Hans: 可用模型名称 + type: text-input + placeholder: + en_US: A model you have access to (e.g. amazon.titan-text-lite-v1) for validation. + zh_Hans: 为了进行验证,请输入一个您可用的模型名称 (例如:amazon.titan-text-lite-v1) diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 2ea65780f1..46f17fe19b 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -656,6 +656,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if assistant_message_function_call: # start of stream function call delta_assistant_message_function_call_storage = assistant_message_function_call + if delta_assistant_message_function_call_storage.arguments is None: + delta_assistant_message_function_call_storage.arguments = '' if not has_finish_reason: continue diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/__init__.py b/api/core/model_runtime/model_providers/tongyi/text_embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v1.yaml b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v1.yaml new file mode 100644 index 0000000000..eed09f95de --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v1.yaml @@ -0,0 +1,4 @@ +model: text-embedding-v1 +model_type: text-embedding +model_properties: + context_size: 2048 diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v2.yaml b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v2.yaml new file mode 100644 index 0000000000..db2fa861e6 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v2.yaml @@ -0,0 +1,4 @@ +model: text-embedding-v2 +model_type: text-embedding +model_properties: + context_size: 2048 diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py new file mode 100644 index 0000000000..a5f3660fb2 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py @@ -0,0 +1,132 @@ +import time +from typing import Optional + +import dashscope + +from core.model_runtime.entities.model_entities import PriceType +from core.model_runtime.entities.text_embedding_entities import ( + EmbeddingUsage, + TextEmbeddingResult, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import ( + TextEmbeddingModel, +) +from core.model_runtime.model_providers.tongyi._common import _CommonTongyi + + +class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): + """ + Model class for Tongyi text embedding model. + """ + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + ) -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ + credentials_kwargs = self._to_credential_kwargs(credentials) + dashscope.api_key = credentials_kwargs["dashscope_api_key"] + embeddings, embedding_used_tokens = self.embed_documents(model, texts) + + return TextEmbeddingResult( + embeddings=embeddings, + usage=self._calc_response_usage(model, credentials_kwargs, embedding_used_tokens), + model=model + ) + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + if len(texts) == 0: + return 0 + total_num_tokens = 0 + for text in texts: + total_num_tokens += self._get_num_tokens_by_gpt2(text) + + return total_num_tokens + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + # transform credentials to kwargs for model instance + credentials_kwargs = self._to_credential_kwargs(credentials) + dashscope.api_key = credentials_kwargs["dashscope_api_key"] + # call embedding model + self.embed_documents(model=model, texts=["ping"]) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @staticmethod + def embed_documents(model: str, texts: list[str]) -> tuple[list[list[float]], int]: + """Call out to Tongyi's embedding endpoint. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text, and tokens usage. + """ + embeddings = [] + embedding_used_tokens = 0 + for text in texts: + response = dashscope.TextEmbedding.call(model=model, input=text, text_type="document") + data = response.output["embeddings"][0] + embeddings.append(data["embedding"]) + embedding_used_tokens += response.usage["total_tokens"] + + return [list(map(float, e)) for e in embeddings], embedding_used_tokens + + def _calc_response_usage( + self, model: str, credentials: dict, tokens: int + ) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, + credentials=credentials, + price_type=PriceType.INPUT, + tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at + ) + + return usage diff --git a/api/core/model_runtime/model_providers/tongyi/tongyi.yaml b/api/core/model_runtime/model_providers/tongyi/tongyi.yaml index 500fd6e045..441d833f70 100644 --- a/api/core/model_runtime/model_providers/tongyi/tongyi.yaml +++ b/api/core/model_runtime/model_providers/tongyi/tongyi.yaml @@ -17,6 +17,7 @@ help: supported_model_types: - llm - tts + - text-embedding configurate_methods: - predefined-model provider_credential_schema: diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml index ca7b1c1f45..6b5bcc5bcf 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml @@ -32,3 +32,8 @@ parameter_rules: zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. required: false + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 8192 diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml index a768902a77..ddea331c8e 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml @@ -30,3 +30,8 @@ parameter_rules: zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. required: false + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 8192 diff --git a/api/migrations/versions/a8f9b3c45e4a_add_tenant_id_db_index.py b/api/migrations/versions/a8f9b3c45e4a_add_tenant_id_db_index.py new file mode 100644 index 0000000000..62d6faeb1d --- /dev/null +++ b/api/migrations/versions/a8f9b3c45e4a_add_tenant_id_db_index.py @@ -0,0 +1,36 @@ +"""add_tenant_id_db_index + +Revision ID: a8f9b3c45e4a +Revises: 16830a790f0f +Create Date: 2024-03-18 05:07:35.588473 + +""" +from alembic import op + +# revision identifiers, used by Alembic. +revision = 'a8f9b3c45e4a' +down_revision = '16830a790f0f' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('document_segments', schema=None) as batch_op: + batch_op.create_index('document_segment_tenant_idx', ['tenant_id'], unique=False) + + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.create_index('document_tenant_idx', ['tenant_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.drop_index('document_tenant_idx') + + with op.batch_alter_table('document_segments', schema=None) as batch_op: + batch_op.drop_index('document_segment_tenant_idx') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/b289e2408ee2_add_workflow.py b/api/migrations/versions/b289e2408ee2_add_workflow.py index 8fadf2dc6c..473752d6f7 100644 --- a/api/migrations/versions/b289e2408ee2_add_workflow.py +++ b/api/migrations/versions/b289e2408ee2_add_workflow.py @@ -11,7 +11,7 @@ from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = 'b289e2408ee2' -down_revision = '16830a790f0f' +down_revision = 'a8f9b3c45e4a' branch_labels = None depends_on = None diff --git a/api/models/dataset.py b/api/models/dataset.py index 94664bf49a..031bbe4dc7 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -176,6 +176,7 @@ class Document(db.Model): db.PrimaryKeyConstraint('id', name='document_pkey'), db.Index('document_dataset_id_idx', 'dataset_id'), db.Index('document_is_paused_idx', 'is_paused'), + db.Index('document_tenant_idx', 'tenant_id'), ) # initial fields @@ -334,6 +335,7 @@ class DocumentSegment(db.Model): db.Index('document_segment_tenant_dataset_idx', 'dataset_id', 'tenant_id'), db.Index('document_segment_tenant_document_idx', 'document_id', 'tenant_id'), db.Index('document_segment_dataset_node_idx', 'dataset_id', 'index_node_id'), + db.Index('document_segment_tenant_idx', 'tenant_id'), ) # initial fields