Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine

This commit is contained in:
-LAN- 2025-09-03 13:53:43 +08:00
commit 8c97937cae
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
42 changed files with 1565 additions and 750 deletions

View File

@ -62,9 +62,6 @@ jobs:
- name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py
- name: Run Basedpyright Checks
run: dev/basedpyright-check
- name: Set up dotenvs
run: |
cp docker/.env.example docker/.env

View File

@ -43,6 +43,10 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: uv sync --project api --dev
- name: Run Basedpyright Checks
if: steps.changed-files.outputs.any_changed == 'true'
run: dev/basedpyright-check
- name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example

2
.gitignore vendored
View File

@ -127,6 +127,8 @@ venv.bak/
.mypy_cache/
.dmypy.json
dmypy.json
pyrightconfig.json
!api/pyrightconfig.json
# Pyre type checker
.pyre/

View File

@ -67,7 +67,7 @@ class ModelProviderCredentialApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()
@ -94,7 +94,7 @@ class ModelProviderCredentialApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()

View File

@ -219,7 +219,11 @@ class ModelProviderModelCredentialApi(Resource):
model_load_balancing_service = ModelLoadBalancingService()
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
tenant_id=tenant_id,
provider=provider,
model=args["model"],
model_type=args["model_type"],
config_from=args.get("config_from", ""),
)
if args.get("config_from", "") == "predefined-model":
@ -263,7 +267,7 @@ class ModelProviderModelCredentialApi(Resource):
choices=[mt.value for mt in ModelType],
location="json",
)
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
@ -309,7 +313,7 @@ class ModelProviderModelCredentialApi(Resource):
)
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()

View File

@ -1,5 +1,6 @@
import json
import logging
import re
from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
@ -343,7 +344,65 @@ class ProviderConfiguration(BaseModel):
with Session(db.engine) as new_session:
return _validate(new_session)
def create_provider_credential(self, credentials: dict, credential_name: str) -> None:
def _generate_provider_credential_name(self, session) -> str:
"""
Generate a unique credential name for provider.
:return: credential name
"""
return self._generate_next_api_key_name(
session=session,
query_factory=lambda: select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
),
)
def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str:
"""
Generate a unique credential name for custom model.
:return: credential name
"""
return self._generate_next_api_key_name(
session=session,
query_factory=lambda: select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
),
)
def _generate_next_api_key_name(self, session, query_factory) -> str:
"""
Generate next available API KEY name by finding the highest numbered suffix.
:param session: database session
:param query_factory: function that returns the SQLAlchemy query
:return: next available API KEY name
"""
try:
stmt = query_factory()
credential_records = session.execute(stmt).scalars().all()
if not credential_records:
return "API KEY 1"
# Extract numbers from API KEY pattern using list comprehension
pattern = re.compile(r"^API KEY\s+(\d+)$")
numbers = [
int(match.group(1))
for cr in credential_records
if cr.credential_name and (match := pattern.match(cr.credential_name.strip()))
]
# Return next sequential number
next_number = max(numbers, default=0) + 1
return f"API KEY {next_number}"
except Exception as e:
logger.warning("Error generating next credential name: %s", str(e))
return "API KEY 1"
def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None:
"""
Add custom provider credentials.
:param credentials: provider credentials
@ -351,8 +410,12 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
if credential_name and self._check_provider_credential_name_exists(
credential_name=credential_name, session=session
):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
else:
credential_name = self._generate_provider_credential_name(session)
credentials = self.validate_provider_credentials(credentials=credentials, session=session)
provider_record = self._get_provider_record(session)
@ -395,7 +458,7 @@ class ProviderConfiguration(BaseModel):
self,
credentials: dict,
credential_id: str,
credential_name: str,
credential_name: str | None,
) -> None:
"""
update a saved provider credential (by credential_id).
@ -406,7 +469,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_provider_credential_name_exists(
if credential_name and self._check_provider_credential_name_exists(
credential_name=credential_name, session=session, exclude_id=credential_id
):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
@ -428,9 +491,9 @@ class ProviderConfiguration(BaseModel):
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.credential_name = credential_name
credential_record.updated_at = naive_utc_now()
if credential_name:
credential_record.credential_name = credential_name
session.commit()
if provider_record and provider_record.credential_id == credential_id:
@ -532,13 +595,7 @@ class ProviderConfiguration(BaseModel):
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
)
lb_credentials_cache.delete()
lb_config.credential_id = None
lb_config.encrypted_config = None
lb_config.enabled = False
lb_config.name = "__delete__"
lb_config.updated_at = naive_utc_now()
session.add(lb_config)
session.delete(lb_config)
# Check if this is the currently active credential
provider_record = self._get_provider_record(session)
@ -823,7 +880,7 @@ class ProviderConfiguration(BaseModel):
return _validate(new_session)
def create_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
) -> None:
"""
Create a custom model credential.
@ -834,10 +891,14 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_custom_model_credential_name_exists(
if credential_name and self._check_custom_model_credential_name_exists(
model=model, model_type=model_type, credential_name=credential_name, session=session
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
else:
credential_name = self._generate_custom_model_credential_name(
model=model, model_type=model_type, session=session
)
# validate custom model config
credentials = self.validate_custom_model_credentials(
model_type=model_type, model=model, credentials=credentials, session=session
@ -881,7 +942,7 @@ class ProviderConfiguration(BaseModel):
raise
def update_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str, credential_id: str
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
) -> None:
"""
Update a custom model credential.
@ -894,7 +955,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_custom_model_credential_name_exists(
if credential_name and self._check_custom_model_credential_name_exists(
model=model,
model_type=model_type,
credential_name=credential_name,
@ -926,8 +987,9 @@ class ProviderConfiguration(BaseModel):
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.credential_name = credential_name
credential_record.updated_at = naive_utc_now()
if credential_name:
credential_record.credential_name = credential_name
session.commit()
if provider_model_record and provider_model_record.credential_id == credential_id:
@ -983,12 +1045,7 @@ class ProviderConfiguration(BaseModel):
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
)
lb_credentials_cache.delete()
lb_config.credential_id = None
lb_config.encrypted_config = None
lb_config.enabled = False
lb_config.name = "__delete__"
lb_config.updated_at = naive_utc_now()
session.add(lb_config)
session.delete(lb_config)
# Check if this is the currently active credential
provider_model_record = self._get_custom_model_record(model_type, model, session=session)
@ -1055,6 +1112,7 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider,
model_name=model,
model_type=model_type.to_origin_model_type(),
is_valid=True,
credential_id=credential_id,
)
else:
@ -1608,11 +1666,9 @@ class ProviderConfiguration(BaseModel):
if config.credential_source_type != "custom_model"
]
if len(provider_model_lb_configs) > 1:
load_balancing_enabled = True
if any(config.name == "__delete__" for config in provider_model_lb_configs):
has_invalid_load_balancing_configs = True
load_balancing_enabled = model_setting.load_balancing_enabled
# when the user enable load_balancing but available configs are less than 2 display warning
has_invalid_load_balancing_configs = load_balancing_enabled and len(provider_model_lb_configs) < 2
provider_models.append(
ModelWithProviderEntity(
@ -1634,6 +1690,8 @@ class ProviderConfiguration(BaseModel):
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type not in model_types:
continue
if model_configuration.unadded_to_model_list:
continue
if model and model != model_configuration.model:
continue
try:
@ -1666,11 +1724,9 @@ class ProviderConfiguration(BaseModel):
if config.credential_source_type != "provider"
]
if len(custom_model_lb_configs) > 1:
load_balancing_enabled = True
if any(config.name == "__delete__" for config in custom_model_lb_configs):
has_invalid_load_balancing_configs = True
load_balancing_enabled = model_setting.load_balancing_enabled
# when the user enable load_balancing but available configs are less than 2 display warning
has_invalid_load_balancing_configs = load_balancing_enabled and len(custom_model_lb_configs) < 2
if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials:
status = ModelStatus.CREDENTIAL_REMOVED

View File

@ -111,11 +111,21 @@ class CustomModelConfiguration(BaseModel):
current_credential_id: Optional[str] = None
current_credential_name: Optional[str] = None
available_model_credentials: list[CredentialConfiguration] = []
unadded_to_model_list: Optional[bool] = False
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
class UnaddedModelConfiguration(BaseModel):
"""
Model class for provider unadded model configuration.
"""
model: str
model_type: ModelType
class CustomConfiguration(BaseModel):
"""
Model class for provider custom configuration.
@ -123,6 +133,7 @@ class CustomConfiguration(BaseModel):
provider: Optional[CustomProviderConfiguration] = None
models: list[CustomModelConfiguration] = []
can_added_models: list[UnaddedModelConfiguration] = []
class ModelLoadBalancingConfiguration(BaseModel):
@ -144,6 +155,7 @@ class ModelSettings(BaseModel):
model: str
model_type: ModelType
enabled: bool = True
load_balancing_enabled: bool = False
load_balancing_configs: list[ModelLoadBalancingConfiguration] = []
# pydantic configs

View File

@ -1,8 +1,9 @@
import contextlib
import json
from collections import defaultdict
from collections.abc import Sequence
from json import JSONDecodeError
from typing import Any, Optional
from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
@ -22,6 +23,7 @@ from core.entities.provider_entities import (
QuotaConfiguration,
QuotaUnit,
SystemConfiguration,
UnaddedModelConfiguration,
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
@ -537,6 +539,23 @@ class ProviderManager:
for credential in available_credentials
]
@staticmethod
def get_credentials_from_provider_model(tenant_id: str, provider_name: str) -> Sequence[ProviderModelCredential]:
"""
Get all the credentials records from ProviderModelCredential by provider_name
:param tenant_id: workspace id
:param provider_name: provider name
"""
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == tenant_id, ProviderModelCredential.provider_name == provider_name
)
all_credentials = session.scalars(stmt).all()
return all_credentials
@staticmethod
def _init_trial_provider_records(
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
@ -623,6 +642,44 @@ class ProviderManager:
:param provider_model_records: provider model records
:return:
"""
# Get custom provider configuration
custom_provider_configuration = self._get_custom_provider_configuration(
tenant_id, provider_entity, provider_records
)
# Get all model credentials once
all_model_credentials = self.get_credentials_from_provider_model(tenant_id, provider_entity.provider)
# Get custom models which have not been added to the model list yet
unadded_models = self._get_can_added_models(provider_model_records, all_model_credentials)
# Get custom model configurations
custom_model_configurations = self._get_custom_model_configurations(
tenant_id, provider_entity, provider_model_records, unadded_models, all_model_credentials
)
can_added_models = [
UnaddedModelConfiguration(model=model["model"], model_type=model["model_type"]) for model in unadded_models
]
return CustomConfiguration(
provider=custom_provider_configuration,
models=custom_model_configurations,
can_added_models=can_added_models,
)
def _get_custom_provider_configuration(
self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider]
) -> CustomProviderConfiguration | None:
"""Get custom provider configuration."""
# Find custom provider record (non-system)
custom_provider_record = next(
(record for record in provider_records if record.provider_type != ProviderType.SYSTEM.value), None
)
if not custom_provider_record:
return None
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
provider_entity.provider_credential_schema.credential_form_schemas
@ -630,113 +687,98 @@ class ProviderManager:
else []
)
# Get custom provider record
custom_provider_record = None
for provider_record in provider_records:
if provider_record.provider_type == ProviderType.SYSTEM.value:
continue
# Get and decrypt provider credentials
provider_credentials = self._get_and_decrypt_credentials(
tenant_id=tenant_id,
record_id=custom_provider_record.id,
encrypted_config=custom_provider_record.encrypted_config,
secret_variables=provider_credential_secret_variables,
cache_type=ProviderCredentialsCacheType.PROVIDER,
is_provider=True,
)
custom_provider_record = provider_record
return CustomProviderConfiguration(
credentials=provider_credentials,
current_credential_name=custom_provider_record.credential_name,
current_credential_id=custom_provider_record.credential_id,
available_credentials=self.get_provider_available_credentials(
tenant_id, custom_provider_record.provider_name
),
)
# Get custom provider credentials
custom_provider_configuration = None
if custom_provider_record:
provider_credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id,
identity_id=custom_provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER,
)
def _get_can_added_models(
self, provider_model_records: list[ProviderModel], all_model_credentials: Sequence[ProviderModelCredential]
) -> list[dict]:
"""Get the custom models and credentials from enterprise version which haven't add to the model list"""
existing_model_set = {(record.model_name, record.model_type) for record in provider_model_records}
# Get cached provider credentials
cached_provider_credentials = provider_credentials_cache.get()
# Get not added custom models credentials
not_added_custom_models_credentials = [
credential
for credential in all_model_credentials
if (credential.model_name, credential.model_type) not in existing_model_set
]
if not cached_provider_credentials:
try:
# fix origin data
if custom_provider_record.encrypted_config is None:
provider_credentials = {}
elif not custom_provider_record.encrypted_config.startswith("{"):
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
else:
provider_credentials = json.loads(custom_provider_record.encrypted_config)
except JSONDecodeError:
provider_credentials = {}
# Group credentials by model
model_to_credentials = defaultdict(list)
for credential in not_added_custom_models_credentials:
model_to_credentials[(credential.model_name, credential.model_type)].append(credential)
# Get decoding rsa key and cipher for decrypting credentials
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
return [
{
"model": model_key[0],
"model_type": ModelType.value_of(model_key[1]),
"available_model_credentials": [
CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name)
for cred in creds
],
}
for model_key, creds in model_to_credentials.items()
]
for variable in provider_credential_secret_variables:
if variable in provider_credentials:
with contextlib.suppress(ValueError):
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable) or "", # type: ignore
self.decoding_rsa_key,
self.decoding_cipher_rsa,
)
# cache provider credentials
provider_credentials_cache.set(credentials=provider_credentials)
else:
provider_credentials = cached_provider_credentials
custom_provider_configuration = CustomProviderConfiguration(
credentials=provider_credentials,
current_credential_name=custom_provider_record.credential_name,
current_credential_id=custom_provider_record.credential_id,
available_credentials=self.get_provider_available_credentials(
tenant_id, custom_provider_record.provider_name
),
)
# Get provider model credential secret variables
def _get_custom_model_configurations(
self,
tenant_id: str,
provider_entity: ProviderEntity,
provider_model_records: list[ProviderModel],
can_added_models: list[dict],
all_model_credentials: Sequence[ProviderModelCredential],
) -> list[CustomModelConfiguration]:
"""Get custom model configurations."""
# Get model credential secret variables
model_credential_secret_variables = self._extract_secret_variables(
provider_entity.model_credential_schema.credential_form_schemas
if provider_entity.model_credential_schema
else []
)
# Get custom provider model credentials
# Create credentials lookup for efficient access
credentials_map = defaultdict(list)
for credential in all_model_credentials:
credentials_map[(credential.model_name, credential.model_type)].append(credential)
custom_model_configurations = []
# Process existing model records
for provider_model_record in provider_model_records:
available_model_credentials = self.get_provider_model_available_credentials(
tenant_id,
provider_model_record.provider_name,
provider_model_record.model_name,
provider_model_record.model_type,
# Use pre-fetched credentials instead of individual database calls
available_model_credentials = [
CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name)
for cred in credentials_map.get(
(provider_model_record.model_name, provider_model_record.model_type), []
)
]
# Get and decrypt model credentials
provider_model_credentials = self._get_and_decrypt_credentials(
tenant_id=tenant_id,
record_id=provider_model_record.id,
encrypted_config=provider_model_record.encrypted_config,
secret_variables=model_credential_secret_variables,
cache_type=ProviderCredentialsCacheType.MODEL,
is_provider=False,
)
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL
)
# Get cached provider model credentials
cached_provider_model_credentials = provider_model_credentials_cache.get()
if not cached_provider_model_credentials and provider_model_record.encrypted_config:
try:
provider_model_credentials = json.loads(provider_model_record.encrypted_config)
except JSONDecodeError:
continue
# Get decoding rsa key and cipher for decrypting credentials
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
for variable in model_credential_secret_variables:
if variable in provider_model_credentials:
with contextlib.suppress(ValueError):
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_model_credentials.get(variable),
self.decoding_rsa_key,
self.decoding_cipher_rsa,
)
# cache provider model credentials
provider_model_credentials_cache.set(credentials=provider_model_credentials)
else:
provider_model_credentials = cached_provider_model_credentials
custom_model_configurations.append(
CustomModelConfiguration(
model=provider_model_record.model_name,
@ -748,7 +790,71 @@ class ProviderManager:
)
)
return CustomConfiguration(provider=custom_provider_configuration, models=custom_model_configurations)
# Add models that can be added
for model in can_added_models:
custom_model_configurations.append(
CustomModelConfiguration(
model=model["model"],
model_type=model["model_type"],
credentials=None,
current_credential_id=None,
current_credential_name=None,
available_model_credentials=model["available_model_credentials"],
unadded_to_model_list=True,
)
)
return custom_model_configurations
def _get_and_decrypt_credentials(
self,
tenant_id: str,
record_id: str,
encrypted_config: str | None,
secret_variables: list[str],
cache_type: ProviderCredentialsCacheType,
is_provider: bool = False,
) -> dict:
"""Get and decrypt credentials with caching."""
credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id,
identity_id=record_id,
cache_type=cache_type,
)
# Try to get from cache first
cached_credentials = credentials_cache.get()
if cached_credentials:
return cached_credentials
# Parse encrypted config
if not encrypted_config:
return {}
if is_provider and not encrypted_config.startswith("{"):
return {"openai_api_key": encrypted_config}
try:
credentials = cast(dict, json.loads(encrypted_config))
except JSONDecodeError:
return {}
# Decrypt secret variables
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
for variable in secret_variables:
if variable in credentials:
with contextlib.suppress(ValueError):
credentials[variable] = encrypter.decrypt_token_with_decoding(
credentials.get(variable) or "",
self.decoding_rsa_key,
self.decoding_cipher_rsa,
)
# Cache the decrypted credentials
credentials_cache.set(credentials=credentials)
return credentials
def _to_system_configuration(
self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider]
@ -956,18 +1062,6 @@ class ProviderManager:
load_balancing_model_config.model_name == provider_model_setting.model_name
and load_balancing_model_config.model_type == provider_model_setting.model_type
):
if load_balancing_model_config.name == "__delete__":
# to calculate current model whether has invalidate lb configs
load_balancing_configs.append(
ModelLoadBalancingConfiguration(
id=load_balancing_model_config.id,
name=load_balancing_model_config.name,
credentials={},
credential_source_type=load_balancing_model_config.credential_source_type,
)
)
continue
if not load_balancing_model_config.enabled:
continue
@ -1033,6 +1127,7 @@ class ProviderManager:
model=provider_model_setting.model_name,
model_type=ModelType.value_of(provider_model_setting.model_type),
enabled=provider_model_setting.enabled,
load_balancing_enabled=provider_model_setting.load_balancing_enabled,
load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [],
)
)

View File

@ -13,6 +13,7 @@ from core.entities.provider_entities import (
CustomModelConfiguration,
ProviderQuotaType,
QuotaConfiguration,
UnaddedModelConfiguration,
)
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType
@ -45,6 +46,7 @@ class CustomConfigurationResponse(BaseModel):
current_credential_name: Optional[str] = None
available_credentials: Optional[list[CredentialConfiguration]] = None
custom_models: Optional[list[CustomModelConfiguration]] = None
can_added_models: Optional[list[UnaddedModelConfiguration]] = None
class SystemConfigurationResponse(BaseModel):

View File

@ -3,6 +3,8 @@ import logging
from json import JSONDecodeError
from typing import Optional, Union
from sqlalchemy import or_
from constants import HIDDEN_VALUE
from core.entities.provider_configuration import ProviderConfiguration
from core.helper import encrypter
@ -69,7 +71,7 @@ class ModelLoadBalancingService:
provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
def get_load_balancing_configs(
self, tenant_id: str, provider: str, model: str, model_type: str
self, tenant_id: str, provider: str, model: str, model_type: str, config_from: str = ""
) -> tuple[bool, list[dict]]:
"""
Get load balancing configurations.
@ -100,6 +102,11 @@ class ModelLoadBalancingService:
if provider_model_setting and provider_model_setting.load_balancing_enabled:
is_load_balancing_enabled = True
if config_from == "predefined-model":
credential_source_type = "provider"
else:
credential_source_type = "custom_model"
# Get load balancing configurations
load_balancing_configs = (
db.session.query(LoadBalancingModelConfig)
@ -108,6 +115,10 @@ class ModelLoadBalancingService:
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
or_(
LoadBalancingModelConfig.credential_source_type == credential_source_type,
LoadBalancingModelConfig.credential_source_type.is_(None),
),
)
.order_by(LoadBalancingModelConfig.created_at)
.all()
@ -405,7 +416,7 @@ class ModelLoadBalancingService:
self._clear_credentials_cache(tenant_id, config_id)
else:
# create load balancing config
if name in {"__inherit__", "__delete__"}:
if name == "__inherit__":
raise ValueError("Invalid load balancing config name")
if credential_id:

View File

@ -72,6 +72,7 @@ class ModelProviderService:
provider_config = provider_configuration.custom_configuration.provider
model_config = provider_configuration.custom_configuration.models
can_added_models = provider_configuration.custom_configuration.can_added_models
provider_response = ProviderResponse(
tenant_id=tenant_id,
@ -95,6 +96,7 @@ class ModelProviderService:
current_credential_name=getattr(provider_config, "current_credential_name", None),
available_credentials=getattr(provider_config, "available_credentials", []),
custom_models=model_config,
can_added_models=can_added_models,
),
system_configuration=SystemConfigurationResponse(
enabled=provider_configuration.system_configuration.enabled,
@ -152,7 +154,7 @@ class ModelProviderService:
provider_configuration.validate_provider_credentials(credentials)
def create_provider_credential(
self, tenant_id: str, provider: str, credentials: dict, credential_name: str
self, tenant_id: str, provider: str, credentials: dict, credential_name: str | None
) -> None:
"""
Create and save new provider credentials.
@ -172,7 +174,7 @@ class ModelProviderService:
provider: str,
credentials: dict,
credential_id: str,
credential_name: str,
credential_name: str | None,
) -> None:
"""
update a saved provider credential (by credential_id).
@ -249,7 +251,7 @@ class ModelProviderService:
)
def create_model_credential(
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | None
) -> None:
"""
create and save model credentials.
@ -278,7 +280,7 @@ class ModelProviderService:
model: str,
credentials: dict,
credential_id: str,
credential_name: str,
credential_name: str | None,
) -> None:
"""
update model credentials.

View File

@ -6,4 +6,4 @@ SCRIPT_DIR="$(dirname "$(realpath "$0")")"
cd "$SCRIPT_DIR/.."
# run basedpyright checks
uv --directory api run basedpyright
uv run --directory api --dev basedpyright

View File

@ -1,6 +1,7 @@
import {
isValidElement,
memo,
useCallback,
useMemo,
} from 'react'
import { RiExternalLinkLine } from '@remixicon/react'
@ -23,6 +24,7 @@ export type BaseFieldProps = {
formSchema: FormSchema
field: AnyFieldApi
disabled?: boolean
onChange?: (field: string, value: any) => void
}
const BaseField = ({
fieldClassName,
@ -32,6 +34,7 @@ const BaseField = ({
formSchema,
field,
disabled: propsDisabled,
onChange,
}: BaseFieldProps) => {
const renderI18nObject = useRenderI18nObject()
const {
@ -40,7 +43,6 @@ const BaseField = ({
placeholder,
options,
labelClassName: formLabelClassName,
show_on = [],
disabled: formSchemaDisabled,
} = formSchema
const disabled = propsDisabled || formSchemaDisabled
@ -90,21 +92,11 @@ const BaseField = ({
}) || []
}, [options, renderI18nObject, optionValues])
const value = useStore(field.form.store, s => s.values[field.name])
const values = useStore(field.form.store, (s) => {
return show_on.reduce((acc, condition) => {
acc[condition.variable] = s.values[condition.variable]
return acc
}, {} as Record<string, any>)
})
const show = useMemo(() => {
return show_on.every((condition) => {
const conditionValue = values[condition.variable]
return conditionValue === condition.value
})
}, [values, show_on])
if (!show)
return null
const handleChange = useCallback((value: any) => {
field.handleChange(value)
onChange?.(field.name, value)
}, [field, onChange])
return (
<div className={cn(fieldClassName)}>
@ -124,7 +116,9 @@ const BaseField = ({
name={field.name}
className={cn(inputClassName)}
value={value || ''}
onChange={e => field.handleChange(e.target.value)}
onChange={(e) => {
handleChange(e.target.value)
}}
onBlur={field.handleBlur}
disabled={disabled}
placeholder={memorizedPlaceholder}
@ -139,7 +133,7 @@ const BaseField = ({
type='password'
className={cn(inputClassName)}
value={value || ''}
onChange={e => field.handleChange(e.target.value)}
onChange={e => handleChange(e.target.value)}
onBlur={field.handleBlur}
disabled={disabled}
placeholder={memorizedPlaceholder}
@ -155,7 +149,7 @@ const BaseField = ({
type='number'
className={cn(inputClassName)}
value={value || ''}
onChange={e => field.handleChange(e.target.value)}
onChange={e => handleChange(e.target.value)}
onBlur={field.handleBlur}
disabled={disabled}
placeholder={memorizedPlaceholder}
@ -166,11 +160,14 @@ const BaseField = ({
formSchema.type === FormTypeEnum.select && (
<PureSelect
value={value}
onChange={v => field.handleChange(v)}
onChange={v => handleChange(v)}
disabled={disabled}
placeholder={memorizedPlaceholder}
options={memorizedOptions}
triggerPopupSameWidth
popupProps={{
className: 'max-h-[320px] overflow-y-auto',
}}
/>
)
}
@ -189,7 +186,7 @@ const BaseField = ({
disabled && 'cursor-not-allowed opacity-50',
inputClassName,
)}
onClick={() => !disabled && field.handleChange(option.value)}
onClick={() => !disabled && handleChange(option.value)}
>
{
formSchema.showRadioUI && (

View File

@ -8,7 +8,10 @@ import type {
AnyFieldApi,
AnyFormApi,
} from '@tanstack/react-form'
import { useForm } from '@tanstack/react-form'
import {
useForm,
useStore,
} from '@tanstack/react-form'
import type {
FormRef,
FormSchema,
@ -32,6 +35,7 @@ export type BaseFormProps = {
ref?: FormRef
disabled?: boolean
formFromProps?: AnyFormApi
onChange?: (field: string, value: any) => void
} & Pick<BaseFieldProps, 'fieldClassName' | 'labelClassName' | 'inputContainerClassName' | 'inputClassName'>
const BaseForm = ({
@ -45,6 +49,7 @@ const BaseForm = ({
ref,
disabled,
formFromProps,
onChange,
}: BaseFormProps) => {
const initialDefaultValues = useMemo(() => {
if (defaultValues)
@ -63,6 +68,19 @@ const BaseForm = ({
const { getFormValues } = useGetFormValues(form, formSchemas)
const { getValidators } = useGetValidators()
const showOnValues = useStore(form.store, (s: any) => {
const result: Record<string, any> = {}
formSchemas.forEach((schema) => {
const { show_on } = schema
if (show_on?.length) {
show_on.forEach((condition) => {
result[condition.variable] = s.values[condition.variable]
})
}
})
return result
})
useImperativeHandle(ref, () => {
return {
getForm() {
@ -87,19 +105,29 @@ const BaseForm = ({
inputContainerClassName={inputContainerClassName}
inputClassName={inputClassName}
disabled={disabled}
onChange={onChange}
/>
)
}
return null
}, [formSchemas, fieldClassName, labelClassName, inputContainerClassName, inputClassName, disabled])
}, [formSchemas, fieldClassName, labelClassName, inputContainerClassName, inputClassName, disabled, onChange])
const renderFieldWrapper = useCallback((formSchema: FormSchema) => {
const validators = getValidators(formSchema)
const {
name,
show_on = [],
} = formSchema
const show = show_on?.every((condition) => {
const conditionValue = showOnValues[condition.variable]
return conditionValue === condition.value
})
if (!show)
return null
return (
<form.Field
key={name}
@ -109,7 +137,7 @@ const BaseForm = ({
{renderField}
</form.Field>
)
}, [renderField, form, getValidators])
}, [renderField, form, getValidators, showOnValues])
if (!formSchemas?.length)
return null

View File

@ -199,6 +199,7 @@ export type CustomModelCredential = CustomModel & {
credentials?: Record<string, any>
available_model_credentials?: Credential[]
current_credential_id?: string
current_credential_name?: string
}
export type CredentialWithModel = Credential & {
@ -236,6 +237,10 @@ export type ModelProvider = {
current_credential_name?: string
available_credentials?: Credential[]
custom_models?: CustomModelCredential[]
can_added_models?: {
model: string
model_type: ModelTypeEnum
}[]
}
system_configuration: {
enabled: boolean
@ -323,3 +328,10 @@ export type ModelCredential = {
current_credential_id?: string
current_credential_name?: string
}
export enum ModelModalModeEnum {
configProviderCredential = 'config-provider-credential',
configCustomModel = 'config-custom-model',
addCustomModelToModelList = 'add-custom-model-to-model-list',
configModelCredential = 'config-model-credential',
}

View File

@ -13,6 +13,7 @@ import type {
DefaultModel,
DefaultModelResponse,
Model,
ModelModalModeEnum,
ModelProvider,
ModelTypeEnum,
} from './declarations'
@ -348,29 +349,31 @@ export const useRefreshModel = () => {
export const useModelModalHandler = () => {
const setShowModelModal = useModalContextSelector(state => state.setShowModelModal)
const { handleRefreshModel } = useRefreshModel()
return (
provider: ModelProvider,
configurationMethod: ConfigurationMethodEnum,
CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
isModelCredential?: boolean,
credential?: Credential,
model?: CustomModel,
onUpdate?: () => void,
extra: {
isModelCredential?: boolean,
credential?: Credential,
model?: CustomModel,
onUpdate?: (newPayload: any, formValues?: Record<string, any>) => void,
mode?: ModelModalModeEnum,
} = {},
) => {
setShowModelModal({
payload: {
currentProvider: provider,
currentConfigurationMethod: configurationMethod,
currentCustomConfigurationModelFixedFields: CustomConfigurationModelFixedFields,
isModelCredential,
credential,
model,
isModelCredential: extra.isModelCredential,
credential: extra.credential,
model: extra.model,
mode: extra.mode,
},
onSaveCallback: () => {
handleRefreshModel(provider, configurationMethod, CustomConfigurationModelFixedFields)
onUpdate?.()
onSaveCallback: (newPayload, formValues) => {
extra.onUpdate?.(newPayload, formValues)
},
})
}

View File

@ -1,7 +1,6 @@
import {
memo,
useCallback,
useMemo,
} from 'react'
import { RiAddLine } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
@ -9,20 +8,22 @@ import { Authorized } from '@/app/components/header/account-setting/model-provid
import cn from '@/utils/classnames'
import type {
Credential,
CustomConfigurationModelFixedFields,
CustomModelCredential,
ModelCredential,
ModelProvider,
} from '@/app/components/header/account-setting/model-provider-page/declarations'
import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import Tooltip from '@/app/components/base/tooltip'
import { ConfigurationMethodEnum, ModelModalModeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
type AddCredentialInLoadBalancingProps = {
provider: ModelProvider
model: CustomModelCredential
configurationMethod: ConfigurationMethodEnum
currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields
modelCredential: ModelCredential
onSelectCredential: (credential: Credential) => void
onUpdate?: () => void
onUpdate?: (payload?: any, formValues?: Record<string, any>) => void
onRemove?: (credentialId: string) => void
}
const AddCredentialInLoadBalancing = ({
provider,
@ -31,41 +32,17 @@ const AddCredentialInLoadBalancing = ({
modelCredential,
onSelectCredential,
onUpdate,
onRemove,
}: AddCredentialInLoadBalancingProps) => {
const { t } = useTranslation()
const {
available_credentials,
} = modelCredential
const customModel = configurationMethod === ConfigurationMethodEnum.customizableModel
const isCustomModel = configurationMethod === ConfigurationMethodEnum.customizableModel
const notAllowCustomCredential = provider.allow_custom_token === false
const ButtonComponent = useMemo(() => {
const Item = (
<div className={cn(
'system-sm-medium flex h-8 items-center rounded-lg px-3 text-text-accent hover:bg-state-base-hover',
notAllowCustomCredential && 'cursor-not-allowed opacity-50',
)}>
<RiAddLine className='mr-2 h-4 w-4' />
{
customModel
? t('common.modelProvider.auth.addCredential')
: t('common.modelProvider.auth.addApiKey')
}
</div>
)
if (notAllowCustomCredential) {
return (
<Tooltip
asChild
popupContent={t('plugin.auth.credentialUnavailable')}
>
{Item}
</Tooltip>
)
}
return Item
}, [notAllowCustomCredential, t, customModel])
const handleUpdate = useCallback((payload?: any, formValues?: Record<string, any>) => {
onUpdate?.(payload, formValues)
}, [onUpdate])
const renderTrigger = useCallback((open?: boolean) => {
const Item = (
@ -74,40 +51,40 @@ const AddCredentialInLoadBalancing = ({
open && 'bg-state-base-hover',
)}>
<RiAddLine className='mr-2 h-4 w-4' />
{
customModel
? t('common.modelProvider.auth.addCredential')
: t('common.modelProvider.auth.addApiKey')
}
{t('common.modelProvider.auth.addCredential')}
</div>
)
return Item
}, [t, customModel])
if (!available_credentials?.length)
return ButtonComponent
}, [t, isCustomModel])
return (
<Authorized
provider={provider}
renderTrigger={renderTrigger}
authParams={{
isModelCredential: isCustomModel,
mode: ModelModalModeEnum.configModelCredential,
onUpdate: handleUpdate,
onRemove,
}}
triggerOnlyOpenModal={!available_credentials?.length && !notAllowCustomCredential}
items={[
{
title: customModel ? t('common.modelProvider.auth.modelCredentials') : t('common.modelProvider.auth.apiKeys'),
model: customModel ? model : undefined,
title: isCustomModel ? '' : t('common.modelProvider.auth.apiKeys'),
model: isCustomModel ? model : undefined,
credentials: available_credentials ?? [],
},
]}
showModelTitle={!isCustomModel}
configurationMethod={configurationMethod}
currentCustomConfigurationModelFixedFields={customModel ? {
currentCustomConfigurationModelFixedFields={isCustomModel ? {
__model_name: model.model,
__model_type: model.model_type,
} : undefined}
onItemClick={onSelectCredential}
placement='bottom-start'
onUpdate={onUpdate}
isModelCredential={customModel}
popupTitle={isCustomModel ? t('common.modelProvider.auth.modelCredentials') : ''}
/>
)
}

View File

@ -1,32 +1,39 @@
import {
memo,
useCallback,
useMemo,
useState,
} from 'react'
import { useTranslation } from 'react-i18next'
import {
RiAddCircleFill,
RiAddLine,
} from '@remixicon/react'
import {
Button,
} from '@/app/components/base/button'
import type {
ConfigurationMethodEnum,
CustomConfigurationModelFixedFields,
ModelProvider,
} from '@/app/components/header/account-setting/model-provider-page/declarations'
import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import Authorized from './authorized'
import {
useAuth,
useCustomModels,
} from './hooks'
import { ModelModalModeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import cn from '@/utils/classnames'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import ModelIcon from '../model-icon'
import { useCanAddedModels } from './hooks/use-custom-models'
import { useAuth } from './hooks/use-auth'
import Tooltip from '@/app/components/base/tooltip'
type AddCustomModelProps = {
provider: ModelProvider,
configurationMethod: ConfigurationMethodEnum,
currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
open?: boolean
onOpenChange?: (open: boolean) => void
}
const AddCustomModel = ({
provider,
@ -34,44 +41,32 @@ const AddCustomModel = ({
currentCustomConfigurationModelFixedFields,
}: AddCustomModelProps) => {
const { t } = useTranslation()
const customModels = useCustomModels(provider)
const noModels = !customModels.length
const [open, setOpen] = useState(false)
const canAddedModels = useCanAddedModels(provider)
const noModels = !canAddedModels.length
const {
handleOpenModal,
} = useAuth(provider, configurationMethod, currentCustomConfigurationModelFixedFields, true)
handleOpenModal: handleOpenModalForAddNewCustomModel,
} = useAuth(
provider,
configurationMethod,
currentCustomConfigurationModelFixedFields,
{
isModelCredential: true,
mode: ModelModalModeEnum.configCustomModel,
},
)
const {
handleOpenModal: handleOpenModalForAddCustomModelToModelList,
} = useAuth(
provider,
configurationMethod,
currentCustomConfigurationModelFixedFields,
{
isModelCredential: true,
mode: ModelModalModeEnum.addCustomModelToModelList,
},
)
const notAllowCustomCredential = provider.allow_custom_token === false
const handleClick = useCallback(() => {
if (notAllowCustomCredential)
return
handleOpenModal()
}, [handleOpenModal, notAllowCustomCredential])
const ButtonComponent = useMemo(() => {
const Item = (
<Button
variant='ghost-accent'
size='small'
onClick={handleClick}
className={cn(
notAllowCustomCredential && 'cursor-not-allowed opacity-50',
)}
>
<RiAddCircleFill className='mr-1 h-3.5 w-3.5' />
{t('common.modelProvider.addModel')}
</Button>
)
if (notAllowCustomCredential) {
return (
<Tooltip
asChild
popupContent={t('plugin.auth.credentialUnavailable')}
>
{Item}
</Tooltip>
)
}
return Item
}, [handleClick, notAllowCustomCredential, t])
const renderTrigger = useCallback((open?: boolean) => {
const Item = (
@ -79,32 +74,93 @@ const AddCustomModel = ({
variant='ghost'
size='small'
className={cn(
'text-text-tertiary',
open && 'bg-components-button-ghost-bg-hover',
notAllowCustomCredential && !!noModels && 'cursor-not-allowed opacity-50',
)}
>
<RiAddCircleFill className='mr-1 h-3.5 w-3.5' />
{t('common.modelProvider.addModel')}
</Button>
)
if (notAllowCustomCredential && !!noModels) {
return (
<Tooltip asChild popupContent={t('plugin.auth.credentialUnavailable')}>
{Item}
</Tooltip>
)
}
return Item
}, [t])
if (noModels)
return ButtonComponent
}, [t, notAllowCustomCredential, noModels])
return (
<Authorized
provider={provider}
configurationMethod={ConfigurationMethodEnum.customizableModel}
items={customModels.map(model => ({
model,
credentials: model.available_model_credentials ?? [],
}))}
renderTrigger={renderTrigger}
isModelCredential
enableAddModelCredential
bottomAddModelCredentialText={t('common.modelProvider.auth.addNewModel')}
/>
<PortalToFollowElem
open={open}
onOpenChange={setOpen}
placement='bottom-end'
offset={{
mainAxis: 4,
crossAxis: 0,
}}
>
<PortalToFollowElemTrigger onClick={() => {
if (noModels) {
if (notAllowCustomCredential)
return
handleOpenModalForAddNewCustomModel()
return
}
setOpen(prev => !prev)
}}>
{renderTrigger(open)}
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className='z-[100]'>
<div className='w-[320px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur shadow-lg'>
<div className='max-h-[304px] overflow-y-auto p-1'>
{
canAddedModels.map(model => (
<div
key={model.model}
className='flex h-8 cursor-pointer items-center rounded-lg px-2 hover:bg-state-base-hover'
onClick={() => {
handleOpenModalForAddCustomModelToModelList(undefined, model)
setOpen(false)
}}
>
<ModelIcon
className='mr-1 h-5 w-5 shrink-0'
iconClassName='h-5 w-5'
provider={provider}
modelName={model.model}
/>
<div
className='system-md-regular grow truncate text-text-primary'
title={model.model}
>
{model.model}
</div>
</div>
))
}
</div>
{
!notAllowCustomCredential && (
<div
className='system-xs-medium flex cursor-pointer items-center border-t border-t-divider-subtle p-3 text-text-accent-light-mode-only'
onClick={() => {
handleOpenModalForAddNewCustomModel()
setOpen(false)
}}
>
<RiAddLine className='mr-1 h-4 w-4' />
{t('common.modelProvider.auth.addNewModel')}
</div>
)
}
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
)
}

View File

@ -2,18 +2,17 @@ import {
memo,
useCallback,
} from 'react'
import { RiAddLine } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import CredentialItem from './credential-item'
import type {
Credential,
CustomModel,
CustomModelCredential,
ModelProvider,
} from '../../declarations'
import Button from '@/app/components/base/button'
import Tooltip from '@/app/components/base/tooltip'
import ModelIcon from '../../model-icon'
type AuthorizedItemProps = {
provider: ModelProvider
model?: CustomModelCredential
title?: string
disabled?: boolean
@ -25,8 +24,12 @@ type AuthorizedItemProps = {
onItemClick?: (credential: Credential, model?: CustomModel) => void
enableAddModelCredential?: boolean
notAllowCustomCredential?: boolean
showModelTitle?: boolean
disableDeleteButShowAction?: boolean
disableDeleteTip?: string
}
export const AuthorizedItem = ({
provider,
model,
title,
credentials,
@ -36,10 +39,10 @@ export const AuthorizedItem = ({
showItemSelectedIcon,
selectedCredentialId,
onItemClick,
enableAddModelCredential,
notAllowCustomCredential,
showModelTitle,
disableDeleteButShowAction,
disableDeleteTip,
}: AuthorizedItemProps) => {
const { t } = useTranslation()
const handleEdit = useCallback((credential?: Credential) => {
onEdit?.(credential, model)
}, [onEdit, model])
@ -52,34 +55,29 @@ export const AuthorizedItem = ({
return (
<div className='p-1'>
<div
className='flex h-9 items-center'
>
<div className='h-5 w-5 shrink-0'></div>
<div
className='system-md-medium mx-1 grow truncate text-text-primary'
title={title ?? model?.model}
>
{title ?? model?.model}
</div>
{
enableAddModelCredential && !notAllowCustomCredential && (
<Tooltip
asChild
popupContent={t('common.modelProvider.auth.addModelCredential')}
{
showModelTitle && (
<div
className='flex h-9 items-center px-2'
>
{
model?.model && (
<ModelIcon
className='mr-1 h-5 w-5 shrink-0'
provider={provider}
modelName={model.model}
/>
)
}
<div
className='system-md-medium mx-1 grow truncate text-text-primary'
title={title ?? model?.model}
>
<Button
className='h-6 w-6 shrink-0 rounded-full p-0'
size='small'
variant='secondary-accent'
onClick={() => handleEdit?.()}
>
<RiAddLine className='h-4 w-4' />
</Button>
</Tooltip>
)
}
</div>
{title ?? model?.model}
</div>
</div>
)
}
{
credentials.map(credential => (
<CredentialItem
@ -91,6 +89,8 @@ export const AuthorizedItem = ({
showSelectedIcon={showItemSelectedIcon}
selectedCredentialId={selectedCredentialId}
onItemClick={handleItemClick}
disableDeleteButShowAction={disableDeleteButShowAction}
disableDeleteTip={disableDeleteTip}
/>
))
}

View File

@ -24,6 +24,8 @@ type CredentialItemProps = {
disableRename?: boolean
disableEdit?: boolean
disableDelete?: boolean
disableDeleteButShowAction?: boolean
disableDeleteTip?: string
showSelectedIcon?: boolean
selectedCredentialId?: string
}
@ -36,6 +38,8 @@ const CredentialItem = ({
disableRename,
disableEdit,
disableDelete,
disableDeleteButShowAction,
disableDeleteTip,
showSelectedIcon,
selectedCredentialId,
}: CredentialItemProps) => {
@ -43,6 +47,9 @@ const CredentialItem = ({
const showAction = useMemo(() => {
return !(disableRename && disableEdit && disableDelete)
}, [disableRename, disableEdit, disableDelete])
const disableDeleteWhenSelected = useMemo(() => {
return disableDeleteButShowAction && selectedCredentialId === credential.credential_id
}, [disableDeleteButShowAction, selectedCredentialId, credential.credential_id])
const Item = (
<div
@ -104,16 +111,21 @@ const CredentialItem = ({
}
{
!disableDelete && !credential.from_enterprise && (
<Tooltip popupContent={t('common.operation.delete')}>
<Tooltip popupContent={disableDeleteWhenSelected ? disableDeleteTip : t('common.operation.delete')}>
<ActionButton
className='hover:bg-transparent'
disabled={disabled}
onClick={(e) => {
if (disabled || disableDeleteWhenSelected)
return
e.stopPropagation()
onDelete?.(credential)
}}
>
<RiDeleteBinLine className='h-4 w-4 text-text-tertiary hover:text-text-destructive' />
<RiDeleteBinLine className={cn(
'h-4 w-4 text-text-tertiary',
!disableDeleteWhenSelected && 'hover:text-text-destructive',
disableDeleteWhenSelected && 'opacity-50',
)} />
</ActionButton>
</Tooltip>
)

View File

@ -1,12 +1,11 @@
import {
Fragment,
memo,
useCallback,
useMemo,
useState,
} from 'react'
import {
RiAddLine,
RiEqualizer2Line,
} from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import {
@ -25,6 +24,7 @@ import type {
Credential,
CustomConfigurationModelFixedFields,
CustomModel,
ModelModalModeEnum,
ModelProvider,
} from '../../declarations'
import { useAuth } from '../hooks'
@ -34,15 +34,20 @@ type AuthorizedProps = {
provider: ModelProvider,
configurationMethod: ConfigurationMethodEnum,
currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
isModelCredential?: boolean
authParams?: {
isModelCredential?: boolean
onUpdate?: (newPayload?: any, formValues?: Record<string, any>) => void
onRemove?: (credentialId: string) => void
mode?: ModelModalModeEnum
}
items: {
title?: string
model?: CustomModel
selectedCredential?: Credential
credentials: Credential[]
}[]
selectedCredential?: Credential
disabled?: boolean
renderTrigger?: (open?: boolean) => React.ReactNode
renderTrigger: (open?: boolean) => React.ReactNode
isOpen?: boolean
onOpenChange?: (open: boolean) => void
offset?: PortalToFollowElemOptions['offset']
@ -50,18 +55,22 @@ type AuthorizedProps = {
triggerPopupSameWidth?: boolean
popupClassName?: string
showItemSelectedIcon?: boolean
onUpdate?: () => void
onItemClick?: (credential: Credential, model?: CustomModel) => void
enableAddModelCredential?: boolean
bottomAddModelCredentialText?: string
triggerOnlyOpenModal?: boolean
hideAddAction?: boolean
disableItemClick?: boolean
popupTitle?: string
showModelTitle?: boolean
disableDeleteButShowAction?: boolean
disableDeleteTip?: string
}
const Authorized = ({
provider,
configurationMethod,
currentCustomConfigurationModelFixedFields,
items,
isModelCredential,
selectedCredential,
authParams,
disabled,
renderTrigger,
isOpen,
@ -71,10 +80,14 @@ const Authorized = ({
triggerPopupSameWidth = false,
popupClassName,
showItemSelectedIcon,
onUpdate,
onItemClick,
enableAddModelCredential,
bottomAddModelCredentialText,
triggerOnlyOpenModal,
hideAddAction,
disableItemClick,
popupTitle,
showModelTitle,
disableDeleteButShowAction,
disableDeleteTip,
}: AuthorizedProps) => {
const { t } = useTranslation()
const [isLocalOpen, setIsLocalOpen] = useState(false)
@ -85,6 +98,12 @@ const Authorized = ({
setIsLocalOpen(open)
}, [onOpenChange])
const {
isModelCredential,
onUpdate,
onRemove,
mode,
} = authParams || {}
const {
openConfirmDelete,
closeConfirmDelete,
@ -93,7 +112,17 @@ const Authorized = ({
handleConfirmDelete,
deleteCredentialId,
handleOpenModal,
} = useAuth(provider, configurationMethod, currentCustomConfigurationModelFixedFields, isModelCredential, onUpdate)
} = useAuth(
provider,
configurationMethod,
currentCustomConfigurationModelFixedFields,
{
isModelCredential,
onUpdate,
onRemove,
mode,
},
)
const handleEdit = useCallback((credential?: Credential, model?: CustomModel) => {
handleOpenModal(credential, model)
@ -101,28 +130,18 @@ const Authorized = ({
}, [handleOpenModal, setMergedIsOpen])
const handleItemClick = useCallback((credential: Credential, model?: CustomModel) => {
if (disableItemClick)
return
if (onItemClick)
onItemClick(credential, model)
else
handleActiveCredential(credential, model)
setMergedIsOpen(false)
}, [handleActiveCredential, onItemClick, setMergedIsOpen])
}, [handleActiveCredential, onItemClick, setMergedIsOpen, disableItemClick])
const notAllowCustomCredential = provider.allow_custom_token === false
const Trigger = useMemo(() => {
const Item = (
<Button
className='grow'
size='small'
>
<RiEqualizer2Line className='mr-1 h-3.5 w-3.5' />
{t('common.operation.config')}
</Button>
)
return Item
}, [t])
return (
<>
<PortalToFollowElem
@ -134,44 +153,60 @@ const Authorized = ({
>
<PortalToFollowElemTrigger
onClick={() => {
if (triggerOnlyOpenModal) {
handleOpenModal()
return
}
setMergedIsOpen(!mergedIsOpen)
}}
asChild
>
{
renderTrigger
? renderTrigger(mergedIsOpen)
: Trigger
}
{renderTrigger(mergedIsOpen)}
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className='z-[100]'>
<div className={cn(
'w-[360px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur shadow-lg',
'w-[360px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur shadow-lg backdrop-blur-[5px]',
popupClassName,
)}>
{
popupTitle && (
<div className='system-xs-medium px-3 pb-0.5 pt-[10px] text-text-tertiary'>
{popupTitle}
</div>
)
}
<div className='max-h-[304px] overflow-y-auto'>
{
items.map((item, index) => (
<AuthorizedItem
key={index}
title={item.title}
model={item.model}
credentials={item.credentials}
disabled={disabled}
onDelete={openConfirmDelete}
onEdit={handleEdit}
showItemSelectedIcon={showItemSelectedIcon}
selectedCredentialId={selectedCredential?.credential_id}
onItemClick={handleItemClick}
enableAddModelCredential={enableAddModelCredential}
notAllowCustomCredential={notAllowCustomCredential}
/>
<Fragment key={index}>
<AuthorizedItem
provider={provider}
title={item.title}
model={item.model}
credentials={item.credentials}
disabled={disabled}
onDelete={openConfirmDelete}
disableDeleteButShowAction={disableDeleteButShowAction}
disableDeleteTip={disableDeleteTip}
onEdit={handleEdit}
showItemSelectedIcon={showItemSelectedIcon}
selectedCredentialId={item.selectedCredential?.credential_id}
onItemClick={handleItemClick}
showModelTitle={showModelTitle}
/>
{
index !== items.length - 1 && (
<div className='h-[1px] bg-divider-subtle'></div>
)
}
</Fragment>
))
}
</div>
<div className='h-[1px] bg-divider-subtle'></div>
{
isModelCredential && !notAllowCustomCredential && (
isModelCredential && !notAllowCustomCredential && !hideAddAction && (
<div
onClick={() => handleEdit(
undefined,
@ -182,15 +217,15 @@ const Authorized = ({
}
: undefined,
)}
className='system-xs-medium flex h-[30px] cursor-pointer items-center px-3 text-text-accent-light-mode-only'
className='system-xs-medium flex h-[40px] cursor-pointer items-center px-3 text-text-accent-light-mode-only'
>
<RiAddLine className='mr-1 h-4 w-4' />
{bottomAddModelCredentialText ?? t('common.modelProvider.auth.addModelCredential')}
{t('common.modelProvider.auth.addModelCredential')}
</div>
)
}
{
!isModelCredential && !notAllowCustomCredential && (
!isModelCredential && !notAllowCustomCredential && !hideAddAction && (
<div className='p-2'>
<Button
onClick={() => handleEdit()}

View File

@ -25,7 +25,7 @@ const ConfigModel = ({
if (loadBalancingInvalid) {
return (
<div
className='system-2xs-medium-uppercase relative flex h-[18px] items-center rounded-[5px] border border-text-warning bg-components-badge-bg-dimm px-1.5 text-text-warning'
className='system-2xs-medium-uppercase relative flex h-[18px] cursor-pointer items-center rounded-[5px] border border-text-warning bg-components-badge-bg-dimm px-1.5 text-text-warning'
onClick={onClick}
>
<RiScales3Line className='mr-0.5 h-3 w-3' />

View File

@ -1,7 +1,6 @@
import {
memo,
useCallback,
useMemo,
} from 'react'
import { useTranslation } from 'react-i18next'
import {
@ -16,24 +15,18 @@ import type {
} from '@/app/components/header/account-setting/model-provider-page/declarations'
import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import Authorized from './authorized'
import { useAuth, useCredentialStatus } from './hooks'
import { useCredentialStatus } from './hooks'
import Tooltip from '@/app/components/base/tooltip'
import cn from '@/utils/classnames'
type ConfigProviderProps = {
provider: ModelProvider,
configurationMethod: ConfigurationMethodEnum,
currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
}
const ConfigProvider = ({
provider,
configurationMethod,
currentCustomConfigurationModelFixedFields,
}: ConfigProviderProps) => {
const { t } = useTranslation()
const {
handleOpenModal,
} = useAuth(provider, configurationMethod, currentCustomConfigurationModelFixedFields)
const {
hasCredential,
authorized,
@ -42,23 +35,20 @@ const ConfigProvider = ({
available_credentials,
} = useCredentialStatus(provider)
const notAllowCustomCredential = provider.allow_custom_token === false
const handleClick = useCallback(() => {
if (!hasCredential && !notAllowCustomCredential)
handleOpenModal()
}, [handleOpenModal, hasCredential, notAllowCustomCredential])
const ButtonComponent = useMemo(() => {
const renderTrigger = useCallback(() => {
const Item = (
<Button
className={cn('grow', notAllowCustomCredential && 'cursor-not-allowed opacity-50')}
className='grow'
size='small'
onClick={handleClick}
variant={!authorized ? 'secondary-accent' : 'secondary'}
>
<RiEqualizer2Line className='mr-1 h-3.5 w-3.5' />
{t('common.operation.setup')}
{hasCredential && t('common.operation.config')}
{!hasCredential && t('common.operation.setup')}
</Button>
)
if (notAllowCustomCredential) {
if (notAllowCustomCredential && !hasCredential) {
return (
<Tooltip
asChild
@ -69,26 +59,27 @@ const ConfigProvider = ({
)
}
return Item
}, [handleClick, authorized, notAllowCustomCredential, t])
if (!hasCredential)
return ButtonComponent
}, [authorized, hasCredential, notAllowCustomCredential, t])
return (
<Authorized
provider={provider}
configurationMethod={ConfigurationMethodEnum.predefinedModel}
currentCustomConfigurationModelFixedFields={currentCustomConfigurationModelFixedFields}
items={[
{
title: t('common.modelProvider.auth.apiKeys'),
credentials: available_credentials ?? [],
selectedCredential: {
credential_id: current_credential_id ?? '',
credential_name: current_credential_name ?? '',
},
},
]}
selectedCredential={{
credential_id: current_credential_id ?? '',
credential_name: current_credential_name ?? '',
}}
showItemSelectedIcon
showModelTitle
renderTrigger={renderTrigger}
triggerOnlyOpenModal={!hasCredential && !notAllowCustomCredential}
/>
)
}

View File

@ -0,0 +1,115 @@
import {
memo,
useCallback,
useState,
} from 'react'
import { useTranslation } from 'react-i18next'
import {
RiAddLine,
RiArrowDownSLine,
} from '@remixicon/react'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import type { Credential } from '@/app/components/header/account-setting/model-provider-page/declarations'
import CredentialItem from './authorized/credential-item'
import Badge from '@/app/components/base/badge'
import Indicator from '@/app/components/header/indicator'
type CredentialSelectorProps = {
selectedCredential?: Credential & { addNewCredential?: boolean }
credentials: Credential[]
onSelect: (credential: Credential & { addNewCredential?: boolean }) => void
disabled?: boolean
notAllowAddNewCredential?: boolean
}
const CredentialSelector = ({
selectedCredential,
credentials,
onSelect,
disabled,
notAllowAddNewCredential,
}: CredentialSelectorProps) => {
const { t } = useTranslation()
const [open, setOpen] = useState(false)
const handleSelect = useCallback((credential: Credential & { addNewCredential?: boolean }) => {
setOpen(false)
onSelect(credential)
}, [onSelect])
const handleAddNewCredential = useCallback(() => {
handleSelect({
credential_id: '__add_new_credential',
addNewCredential: true,
credential_name: t('common.modelProvider.auth.addNewModelCredential'),
})
}, [handleSelect, t])
return (
<PortalToFollowElem
open={open}
onOpenChange={setOpen}
triggerPopupSameWidth
>
<PortalToFollowElemTrigger asChild onClick={() => !disabled && setOpen(v => !v)}>
<div className='system-sm-regular flex h-8 w-full items-center justify-between rounded-lg bg-components-input-bg-normal px-2'>
{
selectedCredential && (
<div className='flex items-center'>
{
!selectedCredential.addNewCredential && <Indicator className='ml-1 mr-2 shrink-0' />
}
<div className='system-sm-regular truncate text-components-input-text-filled' title={selectedCredential.credential_name}>{selectedCredential.credential_name}</div>
{
selectedCredential.from_enterprise && (
<Badge className='shrink-0'>Enterprise</Badge>
)
}
</div>
)
}
{
!selectedCredential && (
<div className='system-sm-regular grow truncate text-components-input-text-placeholder'>{t('common.modelProvider.auth.selectModelCredential')}</div>
)
}
<RiArrowDownSLine className='h-4 w-4 text-text-quaternary' />
</div>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className='z-[100]'>
<div className='border-ccomponents-panel-border rounded-xl border-[0.5px] bg-components-panel-bg-blur shadow-lg'>
<div className='max-h-[320px] overflow-y-auto p-1'>
{
credentials.map(credential => (
<CredentialItem
key={credential.credential_id}
credential={credential}
disableDelete
disableEdit
disableRename
onItemClick={handleSelect}
showSelectedIcon
selectedCredentialId={selectedCredential?.credential_id}
/>
))
}
</div>
{
!notAllowAddNewCredential && (
<div
className='system-xs-medium flex h-10 cursor-pointer items-center border-t border-t-divider-subtle px-7 text-text-accent-light-mode-only'
onClick={handleAddNewCredential}
>
<RiAddLine className='mr-1 h-4 w-4' />
{t('common.modelProvider.auth.addNewModelCredential')}
</div>
)
}
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
)
}
export default memo(CredentialSelector)

View File

@ -17,7 +17,7 @@ import type {
export const useGetCredential = (provider: string, isModelCredential?: boolean, credentialId?: string, model?: CustomModel, configFrom?: string) => {
const providerData = useGetProviderCredential(!isModelCredential && !!credentialId, provider, credentialId)
const modelData = useGetModelCredential(!!isModelCredential && !!credentialId, provider, credentialId, model?.model, model?.model_type, configFrom)
const modelData = useGetModelCredential(!!isModelCredential && (!!credentialId || !!model), provider, credentialId, model?.model, model?.model_type, configFrom)
return isModelCredential ? modelData : providerData
}

View File

@ -11,20 +11,32 @@ import type {
Credential,
CustomConfigurationModelFixedFields,
CustomModel,
ModelModalModeEnum,
ModelProvider,
} from '../../declarations'
import {
useModelModalHandler,
useRefreshModel,
} from '@/app/components/header/account-setting/model-provider-page/hooks'
import { useDeleteModel } from '@/service/use-models'
export const useAuth = (
provider: ModelProvider,
configurationMethod: ConfigurationMethodEnum,
currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
isModelCredential?: boolean,
onUpdate?: () => void,
extra: {
isModelCredential?: boolean,
onUpdate?: (newPayload?: any, formValues?: Record<string, any>) => void,
onRemove?: (credentialId: string) => void,
mode?: ModelModalModeEnum,
} = {},
) => {
const {
isModelCredential,
onUpdate,
onRemove,
mode,
} = extra
const { t } = useTranslation()
const { notify } = useToastContext()
const {
@ -33,22 +45,30 @@ export const useAuth = (
getEditCredentialService,
getAddCredentialService,
} = useAuthService(provider.provider)
const { mutateAsync: deleteModelService } = useDeleteModel(provider.provider)
const handleOpenModelModal = useModelModalHandler()
const { handleRefreshModel } = useRefreshModel()
const pendingOperationCredentialId = useRef<string | null>(null)
const pendingOperationModel = useRef<CustomModel | null>(null)
const [deleteCredentialId, setDeleteCredentialId] = useState<string | null>(null)
const handleSetDeleteCredentialId = useCallback((credentialId: string | null) => {
setDeleteCredentialId(credentialId)
pendingOperationCredentialId.current = credentialId
}, [])
const pendingOperationModel = useRef<CustomModel | null>(null)
const [deleteModel, setDeleteModel] = useState<CustomModel | null>(null)
const handleSetDeleteModel = useCallback((model: CustomModel | null) => {
setDeleteModel(model)
pendingOperationModel.current = model
}, [])
const openConfirmDelete = useCallback((credential?: Credential, model?: CustomModel) => {
if (credential)
pendingOperationCredentialId.current = credential.credential_id
handleSetDeleteCredentialId(credential.credential_id)
if (model)
pendingOperationModel.current = model
setDeleteCredentialId(pendingOperationCredentialId.current)
handleSetDeleteModel(model)
}, [])
const closeConfirmDelete = useCallback(() => {
setDeleteCredentialId(null)
pendingOperationCredentialId.current = null
handleSetDeleteCredentialId(null)
handleSetDeleteModel(null)
}, [])
const [doingAction, setDoingAction] = useState(false)
const doingActionRef = useRef(doingAction)
@ -70,45 +90,49 @@ export const useAuth = (
type: 'success',
message: t('common.api.actionSuccess'),
})
onUpdate?.()
handleRefreshModel(provider, configurationMethod, undefined)
}
finally {
handleSetDoingAction(false)
}
}, [getActiveCredentialService, onUpdate, notify, t, handleSetDoingAction])
}, [getActiveCredentialService, notify, t, handleSetDoingAction])
const handleConfirmDelete = useCallback(async () => {
if (doingActionRef.current)
return
if (!pendingOperationCredentialId.current) {
setDeleteCredentialId(null)
if (!pendingOperationCredentialId.current && !pendingOperationModel.current) {
closeConfirmDelete()
return
}
try {
handleSetDoingAction(true)
await getDeleteCredentialService(!!isModelCredential)({
credential_id: pendingOperationCredentialId.current,
model: pendingOperationModel.current?.model,
model_type: pendingOperationModel.current?.model_type,
})
let payload: any = {}
if (pendingOperationCredentialId.current) {
payload = {
credential_id: pendingOperationCredentialId.current,
model: pendingOperationModel.current?.model,
model_type: pendingOperationModel.current?.model_type,
}
await getDeleteCredentialService(!!isModelCredential)(payload)
}
if (!pendingOperationCredentialId.current && pendingOperationModel.current) {
payload = {
model: pendingOperationModel.current.model,
model_type: pendingOperationModel.current.model_type,
}
await deleteModelService(payload)
}
notify({
type: 'success',
message: t('common.api.actionSuccess'),
})
onUpdate?.()
handleRefreshModel(provider, configurationMethod, undefined)
setDeleteCredentialId(null)
pendingOperationCredentialId.current = null
pendingOperationModel.current = null
onRemove?.(pendingOperationCredentialId.current ?? '')
closeConfirmDelete()
}
finally {
handleSetDoingAction(false)
}
}, [onUpdate, notify, t, handleSetDoingAction, getDeleteCredentialService, isModelCredential])
const handleAddCredential = useCallback((model?: CustomModel) => {
if (model)
pendingOperationModel.current = model
}, [])
}, [notify, t, handleSetDoingAction, getDeleteCredentialService, isModelCredential, closeConfirmDelete, handleRefreshModel, provider, configurationMethod, deleteModelService])
const handleSaveCredential = useCallback(async (payload: Record<string, any>) => {
if (doingActionRef.current)
return
@ -123,24 +147,35 @@ export const useAuth = (
if (res.result === 'success') {
notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
onUpdate?.()
handleRefreshModel(provider, configurationMethod, undefined)
}
}
finally {
handleSetDoingAction(false)
}
}, [onUpdate, notify, t, handleSetDoingAction, getEditCredentialService, getAddCredentialService])
}, [notify, t, handleSetDoingAction, getEditCredentialService, getAddCredentialService])
const handleOpenModal = useCallback((credential?: Credential, model?: CustomModel) => {
handleOpenModelModal(
provider,
configurationMethod,
currentCustomConfigurationModelFixedFields,
isModelCredential,
credential,
model,
onUpdate,
{
isModelCredential,
credential,
model,
onUpdate,
mode,
},
)
}, [handleOpenModelModal, provider, configurationMethod, currentCustomConfigurationModelFixedFields, isModelCredential, onUpdate])
}, [
handleOpenModelModal,
provider,
configurationMethod,
currentCustomConfigurationModelFixedFields,
isModelCredential,
onUpdate,
mode,
])
return {
pendingOperationCredentialId,
@ -150,8 +185,8 @@ export const useAuth = (
doingAction,
handleActiveCredential,
handleConfirmDelete,
handleAddCredential,
deleteCredentialId,
deleteModel,
handleSaveCredential,
handleOpenModal,
}

View File

@ -7,3 +7,9 @@ export const useCustomModels = (provider: ModelProvider) => {
return custom_models || []
}
export const useCanAddedModels = (provider: ModelProvider) => {
const { can_added_models } = provider.custom_configuration
return can_added_models || []
}

View File

@ -3,7 +3,6 @@ import { useTranslation } from 'react-i18next'
import type {
Credential,
CustomModelCredential,
ModelLoadBalancingConfig,
ModelProvider,
} from '../../declarations'
import {
@ -18,7 +17,6 @@ export const useModelFormSchemas = (
credentials?: Record<string, any>,
credential?: Credential,
model?: CustomModelCredential,
draftConfig?: ModelLoadBalancingConfig,
) => {
const { t } = useTranslation()
const {
@ -27,26 +25,15 @@ export const useModelFormSchemas = (
model_credential_schema,
} = provider
const formSchemas = useMemo(() => {
const modelTypeSchema = genModelTypeFormSchema(supported_model_types)
const modelNameSchema = genModelNameFormSchema(model_credential_schema?.model)
if (!!model) {
modelTypeSchema.disabled = true
modelNameSchema.disabled = true
}
return providerFormSchemaPredefined
? provider_credential_schema.credential_form_schemas
: [
modelTypeSchema,
modelNameSchema,
...(draftConfig?.enabled ? [] : model_credential_schema.credential_form_schemas),
]
: model_credential_schema.credential_form_schemas
}, [
providerFormSchemaPredefined,
provider_credential_schema?.credential_form_schemas,
supported_model_types,
model_credential_schema?.credential_form_schemas,
model_credential_schema?.model,
draftConfig?.enabled,
model,
])
@ -55,7 +42,7 @@ export const useModelFormSchemas = (
type: FormTypeEnum.textInput,
variable: '__authorization_name__',
label: t('plugin.auth.authorizationName'),
required: true,
required: false,
}
return [
@ -79,8 +66,33 @@ export const useModelFormSchemas = (
return result
}, [credentials, credential, model, formSchemas])
const modelNameAndTypeFormSchemas = useMemo(() => {
if (providerFormSchemaPredefined)
return []
const modelNameSchema = genModelNameFormSchema(model_credential_schema?.model)
const modelTypeSchema = genModelTypeFormSchema(supported_model_types)
return [
modelNameSchema,
modelTypeSchema,
]
}, [supported_model_types, model_credential_schema?.model, providerFormSchemaPredefined])
const modelNameAndTypeFormValues = useMemo(() => {
let result = {}
if (providerFormSchemaPredefined)
return result
if (model)
result = { ...result, __model_name: model?.model, __model_type: model?.model_type }
return result
}, [model, providerFormSchemaPredefined])
return {
formSchemas: formSchemasWithAuthorizationName,
formValues,
modelNameAndTypeFormSchemas,
modelNameAndTypeFormValues,
}
}

View File

@ -4,3 +4,5 @@ export { default as AddCredentialInLoadBalancing } from './add-credential-in-loa
export { default as AddCustomModel } from './add-custom-model'
export { default as ConfigProvider } from './config-provider'
export { default as ConfigModel } from './config-model'
export { default as ManageCustomModelCredentials } from './manage-custom-model-credentials'
export { default as CredentialSelector } from './credential-selector'

View File

@ -0,0 +1,82 @@
import {
memo,
useCallback,
} from 'react'
import { useTranslation } from 'react-i18next'
import {
Button,
} from '@/app/components/base/button'
import type {
CustomConfigurationModelFixedFields,
ModelProvider,
} from '@/app/components/header/account-setting/model-provider-page/declarations'
import {
ConfigurationMethodEnum,
ModelModalModeEnum,
} from '@/app/components/header/account-setting/model-provider-page/declarations'
import Authorized from './authorized'
import {
useCustomModels,
} from './hooks'
import cn from '@/utils/classnames'
type ManageCustomModelCredentialsProps = {
provider: ModelProvider,
currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
}
const ManageCustomModelCredentials = ({
provider,
currentCustomConfigurationModelFixedFields,
}: ManageCustomModelCredentialsProps) => {
const { t } = useTranslation()
const customModels = useCustomModels(provider)
const noModels = !customModels.length
const renderTrigger = useCallback((open?: boolean) => {
const Item = (
<Button
variant='ghost'
size='small'
className={cn(
'mr-0.5 text-text-tertiary',
open && 'bg-components-button-ghost-bg-hover',
)}
>
{t('common.modelProvider.auth.manageCredentials')}
</Button>
)
return Item
}, [t])
if (noModels)
return null
return (
<Authorized
provider={provider}
configurationMethod={ConfigurationMethodEnum.customizableModel}
currentCustomConfigurationModelFixedFields={currentCustomConfigurationModelFixedFields}
items={customModels.map(model => ({
model,
credentials: model.available_model_credentials ?? [],
selectedCredential: model.current_credential_id ? {
credential_id: model.current_credential_id,
credential_name: model.current_credential_name,
} : undefined,
}))}
renderTrigger={renderTrigger}
authParams={{
isModelCredential: true,
mode: ModelModalModeEnum.configModelCredential,
}}
hideAddAction
disableItemClick
popupTitle={t('common.modelProvider.auth.customModelCredentials')}
showModelTitle
disableDeleteButShowAction
disableDeleteTip={t('common.modelProvider.auth.customModelCredentialsDeleteTip')}
/>
)
}
export default memo(ManageCustomModelCredentials)

View File

@ -13,7 +13,7 @@ import type {
CustomModel,
ModelProvider,
} from '../declarations'
import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { ConfigurationMethodEnum, ModelModalModeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import cn from '@/utils/classnames'
import Tooltip from '@/app/components/base/tooltip'
import Badge from '@/app/components/base/badge'
@ -24,6 +24,8 @@ type SwitchCredentialInLoadBalancingProps = {
credentials?: Credential[]
customModelCredential?: Credential
setCustomModelCredential: Dispatch<SetStateAction<Credential | undefined>>
onUpdate?: (payload?: any, formValues?: Record<string, any>) => void
onRemove?: (credentialId: string) => void
}
const SwitchCredentialInLoadBalancing = ({
provider,
@ -31,6 +33,8 @@ const SwitchCredentialInLoadBalancing = ({
customModelCredential,
setCustomModelCredential,
credentials,
onUpdate,
onRemove,
}: SwitchCredentialInLoadBalancingProps) => {
const { t } = useTranslation()
@ -94,27 +98,31 @@ const SwitchCredentialInLoadBalancing = ({
<Authorized
provider={provider}
configurationMethod={ConfigurationMethodEnum.customizableModel}
currentCustomConfigurationModelFixedFields={model ? {
__model_name: model.model,
__model_type: model.model_type,
} : undefined}
authParams={{
isModelCredential: true,
mode: ModelModalModeEnum.configModelCredential,
onUpdate,
onRemove,
}}
items={[
{
title: t('common.modelProvider.auth.modelCredentials'),
model,
credentials: credentials || [],
selectedCredential: customModelCredential ? {
credential_id: customModelCredential?.credential_id || '',
credential_name: customModelCredential?.credential_name || '',
} : undefined,
},
]}
renderTrigger={renderTrigger}
onItemClick={handleItemClick}
isModelCredential
enableAddModelCredential
bottomAddModelCredentialText={t('common.modelProvider.auth.addModelCredential')}
selectedCredential={
customModelCredential
? {
credential_id: customModelCredential?.credential_id || '',
credential_name: customModelCredential?.credential_name || '',
}
: undefined
}
showItemSelectedIcon
popupTitle={t('common.modelProvider.auth.modelCredentials')}
/>
)
}

View File

@ -5,6 +5,7 @@ import {
useEffect,
useMemo,
useRef,
useState,
} from 'react'
import { RiCloseLine } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
@ -15,6 +16,7 @@ import type {
import {
ConfigurationMethodEnum,
FormTypeEnum,
ModelModalModeEnum,
} from '../declarations'
import {
useLanguage,
@ -46,16 +48,19 @@ import {
import ModelIcon from '@/app/components/header/account-setting/model-provider-page/model-icon'
import Badge from '@/app/components/base/badge'
import { useRenderI18nObject } from '@/hooks/use-i18n'
import { CredentialSelector } from '../model-auth'
type ModelModalProps = {
provider: ModelProvider
configurateMethod: ConfigurationMethodEnum
currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields
onCancel: () => void
onSave: () => void
onSave: (formValues?: Record<string, any>) => void
onRemove: (formValues?: Record<string, any>) => void
model?: CustomModel
credential?: Credential
isModelCredential?: boolean
mode?: ModelModalModeEnum
}
const ModelModal: FC<ModelModalProps> = ({
@ -67,6 +72,7 @@ const ModelModal: FC<ModelModalProps> = ({
model,
credential,
isModelCredential,
mode = ModelModalModeEnum.configProviderCredential,
}) => {
const renderI18nObject = useRenderI18nObject()
const providerFormSchemaPredefined = configurateMethod === ConfigurationMethodEnum.predefinedModel
@ -81,40 +87,88 @@ const ModelModal: FC<ModelModalProps> = ({
closeConfirmDelete,
openConfirmDelete,
doingAction,
} = useAuth(provider, configurateMethod, currentCustomConfigurationModelFixedFields, isModelCredential, onSave)
handleActiveCredential,
} = useAuth(
provider,
configurateMethod,
currentCustomConfigurationModelFixedFields,
{
isModelCredential,
mode,
},
)
const {
credentials: formSchemasValue,
available_credentials,
} = credentialData as any
const { isCurrentWorkspaceManager } = useAppContext()
const isEditMode = !!formSchemasValue && isCurrentWorkspaceManager
const { t } = useTranslation()
const language = useLanguage()
const {
formSchemas,
formValues,
modelNameAndTypeFormSchemas,
modelNameAndTypeFormValues,
} = useModelFormSchemas(provider, providerFormSchemaPredefined, formSchemasValue, credential, model)
const formRef = useRef<FormRefObject>(null)
const formRef1 = useRef<FormRefObject>(null)
const [selectedCredential, setSelectedCredential] = useState<Credential & { addNewCredential?: boolean } | undefined>()
const formRef2 = useRef<FormRefObject>(null)
const isEditMode = !!Object.keys(formValues).filter((key) => {
return key !== '__model_name' && key !== '__model_type'
}).length && isCurrentWorkspaceManager
const handleSave = useCallback(async () => {
if (mode === ModelModalModeEnum.addCustomModelToModelList && selectedCredential && !selectedCredential?.addNewCredential) {
handleActiveCredential(selectedCredential, model)
onCancel()
return
}
let modelNameAndTypeIsCheckValidated = true
let modelNameAndTypeValues: Record<string, any> = {}
if (mode === ModelModalModeEnum.configCustomModel) {
const formResult = formRef1.current?.getFormValues({
needCheckValidatedValues: true,
}) || { isCheckValidated: false, values: {} }
modelNameAndTypeIsCheckValidated = formResult.isCheckValidated
modelNameAndTypeValues = formResult.values
}
if (mode === ModelModalModeEnum.configModelCredential && model) {
modelNameAndTypeValues = {
__model_name: model.model,
__model_type: model.model_type,
}
}
if (mode === ModelModalModeEnum.addCustomModelToModelList && selectedCredential?.addNewCredential && model) {
modelNameAndTypeValues = {
__model_name: model.model,
__model_type: model.model_type,
}
}
const {
isCheckValidated,
values,
} = formRef.current?.getFormValues({
} = formRef2.current?.getFormValues({
needCheckValidatedValues: true,
needTransformWhenSecretFieldIsPristine: true,
}) || { isCheckValidated: false, values: {} }
if (!isCheckValidated)
if (!isCheckValidated || !modelNameAndTypeIsCheckValidated)
return
const {
__authorization_name__,
__model_name,
__model_type,
} = modelNameAndTypeValues
const {
__authorization_name__,
...rest
} = values
if (__model_name && __model_type) {
handleSaveCredential({
if (__model_name && __model_type && __authorization_name__) {
await handleSaveCredential({
credential_id: credential?.credential_id,
credentials: rest,
name: __authorization_name__,
@ -123,41 +177,33 @@ const ModelModal: FC<ModelModalProps> = ({
})
}
else {
handleSaveCredential({
await handleSaveCredential({
credential_id: credential?.credential_id,
credentials: rest,
name: __authorization_name__,
})
}
}, [handleSaveCredential, credential?.credential_id, model])
onSave(values)
}, [handleSaveCredential, credential?.credential_id, model, onSave, mode, selectedCredential, handleActiveCredential])
const modalTitle = useMemo(() => {
if (!providerFormSchemaPredefined && !model) {
return (
<div className='flex items-center'>
<ModelIcon
className='mr-2 h-10 w-10 shrink-0'
iconClassName='h-10 w-10'
provider={provider}
/>
<div>
<div className='system-xs-medium-uppercase text-text-tertiary'>{t('common.modelProvider.auth.apiKeyModal.addModel')}</div>
<div className='system-md-semibold text-text-primary'>{renderI18nObject(provider.label)}</div>
</div>
</div>
)
}
let label = t('common.modelProvider.auth.apiKeyModal.title')
if (model)
label = t('common.modelProvider.auth.addModelCredential')
if (mode === ModelModalModeEnum.configCustomModel || mode === ModelModalModeEnum.addCustomModelToModelList)
label = t('common.modelProvider.auth.addModel')
if (mode === ModelModalModeEnum.configModelCredential) {
if (credential)
label = t('common.modelProvider.auth.editModelCredential')
else
label = t('common.modelProvider.auth.addModelCredential')
}
return (
<div className='title-2xl-semi-bold text-text-primary'>
{label}
</div>
)
}, [providerFormSchemaPredefined, t, model, renderI18nObject])
}, [t, mode, credential])
const modalDesc = useMemo(() => {
if (providerFormSchemaPredefined) {
@ -172,7 +218,18 @@ const ModelModal: FC<ModelModalProps> = ({
}, [providerFormSchemaPredefined, t])
const modalModel = useMemo(() => {
if (model) {
if (mode === ModelModalModeEnum.configCustomModel) {
return (
<div className='mt-2 flex items-center'>
<ModelIcon
className='mr-2 h-4 w-4 shrink-0'
provider={provider}
/>
<div className='system-md-regular mr-1 text-text-secondary'>{renderI18nObject(provider.label)}</div>
</div>
)
}
if (model && (mode === ModelModalModeEnum.configModelCredential || mode === ModelModalModeEnum.addCustomModelToModelList)) {
return (
<div className='mt-2 flex items-center'>
<ModelIcon
@ -187,7 +244,38 @@ const ModelModal: FC<ModelModalProps> = ({
}
return null
}, [model, provider])
}, [model, provider, mode, renderI18nObject])
const showCredentialLabel = useMemo(() => {
if (mode === ModelModalModeEnum.configCustomModel)
return true
if (mode === ModelModalModeEnum.addCustomModelToModelList)
return selectedCredential?.addNewCredential
}, [mode, selectedCredential])
const showCredentialForm = useMemo(() => {
if (mode !== ModelModalModeEnum.addCustomModelToModelList)
return true
return selectedCredential?.addNewCredential
}, [mode, selectedCredential])
const saveButtonText = useMemo(() => {
if (mode === ModelModalModeEnum.addCustomModelToModelList || mode === ModelModalModeEnum.configCustomModel)
return t('common.operation.add')
return t('common.operation.save')
}, [mode, t])
const handleDeleteCredential = useCallback(() => {
handleConfirmDelete()
onCancel()
}, [handleConfirmDelete])
const handleModelNameAndTypeChange = useCallback((field: string, value: any) => {
const {
getForm,
} = formRef2.current as FormRefObject || {}
if (getForm())
getForm()?.setFieldValue(field, value)
}, [])
const notAllowCustomCredential = provider.allow_custom_token === false
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
@ -214,100 +302,132 @@ const ModelModal: FC<ModelModalProps> = ({
>
<RiCloseLine className='h-4 w-4 text-text-tertiary' />
</div>
<div className='px-6 pt-6'>
<div className='pb-3'>
{modalTitle}
{modalDesc}
{modalModel}
</div>
<div className='max-h-[calc(100vh-320px)] overflow-y-auto'>
{
isLoading && (
<div className='flex items-center justify-center'>
<Loading />
</div>
)
}
{
!isLoading && (
<AuthForm
formSchemas={formSchemas.map((formSchema) => {
return {
...formSchema,
name: formSchema.variable,
showRadioUI: formSchema.type === FormTypeEnum.radio,
}
}) as FormSchema[]}
defaultValues={formValues}
inputClassName='justify-start'
ref={formRef}
/>
)
}
</div>
<div className='sticky bottom-0 -mx-2 mt-2 flex flex-wrap items-center justify-between gap-y-2 bg-components-panel-bg px-2 pb-6 pt-4'>
{
(provider.help && (provider.help.title || provider.help.url))
? (
<a
href={provider.help?.url[language] || provider.help?.url.en_US}
target='_blank' rel='noopener noreferrer'
className='inline-flex items-center text-xs text-primary-600'
onClick={e => !provider.help.url && e.preventDefault()}
>
{provider.help.title?.[language] || provider.help.url[language] || provider.help.title?.en_US || provider.help.url.en_US}
<LinkExternal02 className='ml-1 h-3 w-3' />
</a>
)
: <div />
}
<div>
{
isEditMode && (
<Button
variant='warning'
size='large'
className='mr-2'
onClick={() => openConfirmDelete(credential, model)}
>
{t('common.operation.remove')}
</Button>
)
}
<Button
size='large'
className='mr-2'
onClick={onCancel}
>
{t('common.operation.cancel')}
</Button>
<Button
size='large'
variant='primary'
onClick={handleSave}
disabled={isLoading || doingAction}
>
{t('common.operation.save')}
</Button>
</div>
</div>
<div className='p-6 pb-3'>
{modalTitle}
{modalDesc}
{modalModel}
</div>
<div className='border-t-[0.5px] border-t-divider-regular'>
<div className='flex items-center justify-center rounded-b-2xl bg-background-section-burn py-3 text-xs text-text-tertiary'>
<Lock01 className='mr-1 h-3 w-3 text-text-tertiary' />
{t('common.modelProvider.encrypted.front')}
<a
className='mx-1 text-text-accent'
target='_blank' rel='noopener noreferrer'
href='https://pycryptodome.readthedocs.io/en/latest/src/cipher/oaep.html'
<div className='max-h-[calc(100vh-320px)] overflow-y-auto px-6 py-3'>
{
mode === ModelModalModeEnum.configCustomModel && (
<AuthForm
formSchemas={modelNameAndTypeFormSchemas.map((formSchema) => {
return {
...formSchema,
name: formSchema.variable,
}
}) as FormSchema[]}
defaultValues={modelNameAndTypeFormValues}
inputClassName='justify-start'
ref={formRef1}
onChange={handleModelNameAndTypeChange}
/>
)
}
{
mode === ModelModalModeEnum.addCustomModelToModelList && (
<CredentialSelector
credentials={available_credentials || []}
onSelect={setSelectedCredential}
selectedCredential={selectedCredential}
disabled={isLoading}
notAllowAddNewCredential={notAllowCustomCredential}
/>
)
}
{
showCredentialLabel && (
<div className='system-xs-medium-uppercase mb-3 mt-6 flex items-center text-text-tertiary'>
{t('common.modelProvider.auth.modelCredential')}
<div className='ml-2 h-px grow bg-gradient-to-r from-divider-regular to-background-gradient-mask-transparent' />
</div>
)
}
{
isLoading && (
<div className='mt-3 flex items-center justify-center'>
<Loading />
</div>
)
}
{
!isLoading
&& showCredentialForm
&& (
<AuthForm
formSchemas={formSchemas.map((formSchema) => {
return {
...formSchema,
name: formSchema.variable,
showRadioUI: formSchema.type === FormTypeEnum.radio,
}
}) as FormSchema[]}
defaultValues={formValues}
inputClassName='justify-start'
ref={formRef2}
/>
)
}
</div>
<div className='flex justify-between p-6 pt-5'>
{
(provider.help && (provider.help.title || provider.help.url))
? (
<a
href={provider.help?.url[language] || provider.help?.url.en_US}
target='_blank' rel='noopener noreferrer'
className='system-xs-regular mt-2 inline-flex items-center text-text-accent'
onClick={e => !provider.help.url && e.preventDefault()}
>
{provider.help.title?.[language] || provider.help.url[language] || provider.help.title?.en_US || provider.help.url.en_US}
<LinkExternal02 className='ml-1 h-3 w-3' />
</a>
)
: <div />
}
<div className='flex items-center justify-end space-x-2'>
{
isEditMode && (
<Button
variant='warning'
onClick={() => openConfirmDelete(credential, model)}
>
{t('common.operation.remove')}
</Button>
)
}
<Button
onClick={onCancel}
>
PKCS1_OAEP
</a>
{t('common.modelProvider.encrypted.back')}
{t('common.operation.cancel')}
</Button>
<Button
variant='primary'
onClick={handleSave}
disabled={isLoading || doingAction}
>
{saveButtonText}
</Button>
</div>
</div>
{
(mode === ModelModalModeEnum.configCustomModel || mode === ModelModalModeEnum.configProviderCredential) && (
<div className='border-t-[0.5px] border-t-divider-regular'>
<div className='flex items-center justify-center rounded-b-2xl bg-background-section-burn py-3 text-xs text-text-tertiary'>
<Lock01 className='mr-1 h-3 w-3 text-text-tertiary' />
{t('common.modelProvider.encrypted.front')}
<a
className='mx-1 text-text-accent'
target='_blank' rel='noopener noreferrer'
href='https://pycryptodome.readthedocs.io/en/latest/src/cipher/oaep.html'
>
PKCS1_OAEP
</a>
{t('common.modelProvider.encrypted.back')}
</div>
</div>
)
}
</div>
{
deleteCredentialId && (
@ -316,7 +436,7 @@ const ModelModal: FC<ModelModalProps> = ({
title={t('common.modelProvider.confirmDelete')}
isDisabled={doingAction}
onCancel={closeConfirmDelete}
onConfirm={handleConfirmDelete}
onConfirm={handleDeleteCredential}
/>
)
}

View File

@ -111,7 +111,6 @@ const CredentialPanel = ({
<div className='flex items-center gap-0.5'>
<ConfigProvider
provider={provider}
configurationMethod={ConfigurationMethodEnum.predefinedModel}
/>
{
systemConfig.enabled && isCustomConfigured && (

View File

@ -25,7 +25,10 @@ import { useEventEmitterContextContext } from '@/context/event-emitter'
import { IS_CE_EDITION } from '@/config'
import { useAppContext } from '@/context/app-context'
import cn from '@/utils/classnames'
import { AddCustomModel } from '@/app/components/header/account-setting/model-provider-page/model-auth'
import {
AddCustomModel,
ManageCustomModelCredentials,
} from '@/app/components/header/account-setting/model-provider-page/model-auth'
export const UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST = 'UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST'
type ProviderAddedCardProps = {
@ -155,10 +158,17 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
)}
{
configurationMethods.includes(ConfigurationMethodEnum.customizableModel) && isCurrentWorkspaceManager && (
<AddCustomModel
provider={provider}
configurationMethod={ConfigurationMethodEnum.customizableModel}
/>
<div className='flex grow justify-end'>
<ManageCustomModelCredentials
provider={provider}
currentCustomConfigurationModelFixedFields={undefined}
/>
<AddCustomModel
provider={provider}
configurationMethod={ConfigurationMethodEnum.customizableModel}
currentCustomConfigurationModelFixedFields={undefined}
/>
</div>
)
}
</div>

View File

@ -16,7 +16,10 @@ import {
import ModelListItem from './model-list-item'
import { useModalContextSelector } from '@/context/modal-context'
import { useAppContext } from '@/context/app-context'
import { AddCustomModel } from '@/app/components/header/account-setting/model-provider-page/model-auth'
import {
AddCustomModel,
ManageCustomModelCredentials,
} from '@/app/components/header/account-setting/model-provider-page/model-auth'
type ModelListProps = {
provider: ModelProvider
@ -67,6 +70,10 @@ const ModelList: FC<ModelListProps> = ({
{
isConfigurable && isCurrentWorkspaceManager && (
<div className='flex grow justify-end'>
<ManageCustomModelCredentials
provider={provider}
currentCustomConfigurationModelFixedFields={undefined}
/>
<AddCustomModel
provider={provider}
configurationMethod={ConfigurationMethodEnum.customizableModel}

View File

@ -2,8 +2,7 @@ import type { Dispatch, SetStateAction } from 'react'
import { useCallback, useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import {
RiDeleteBinLine,
RiEqualizer2Line,
RiIndeterminateCircleLine,
} from '@remixicon/react'
import type {
Credential,
@ -28,7 +27,6 @@ import GridMask from '@/app/components/base/grid-mask'
import { useProviderContextSelector } from '@/context/provider-context'
import { IS_CE_EDITION } from '@/config'
import { AddCredentialInLoadBalancing } from '@/app/components/header/account-setting/model-provider-page/model-auth'
import { useModelModalHandler } from '@/app/components/header/account-setting/model-provider-page/hooks'
import Badge from '@/app/components/base/badge/index'
export type ModelLoadBalancingConfigsProps = {
@ -40,7 +38,8 @@ export type ModelLoadBalancingConfigsProps = {
withSwitch?: boolean
className?: string
modelCredential: ModelCredential
onUpdate?: () => void
onUpdate?: (payload?: any, formValues?: Record<string, any>) => void
onRemove?: (credentialId: string) => void
model: CustomModelCredential
}
@ -55,11 +54,11 @@ const ModelLoadBalancingConfigs = ({
className,
modelCredential,
onUpdate,
onRemove,
}: ModelLoadBalancingConfigsProps) => {
const { t } = useTranslation()
const providerFormSchemaPredefined = configurationMethod === ConfigurationMethodEnum.predefinedModel
const modelLoadBalancingEnabled = useProviderContextSelector(state => state.modelLoadBalancingEnabled)
const handleOpenModal = useModelModalHandler()
const updateConfigEntry = useCallback(
(
@ -130,6 +129,17 @@ const ModelLoadBalancingConfigs = ({
return draftConfig.configs
}, [draftConfig])
const handleUpdate = useCallback((payload?: any, formValues?: Record<string, any>) => {
onUpdate?.(payload, formValues)
}, [onUpdate])
const handleRemove = useCallback((credentialId: string) => {
const index = draftConfig?.configs.findIndex(item => item.credential_id === credentialId && item.name !== '__inherit__')
if (index && index > -1)
updateConfigEntry(index, () => undefined)
onRemove?.(credentialId)
}, [draftConfig?.configs, updateConfigEntry, onRemove])
if (!draftConfig)
return null
@ -190,7 +200,7 @@ const ModelLoadBalancingConfigs = ({
</Tooltip>
)}
</div>
<div className='mr-1 text-[13px]'>
<div className='mr-1 text-[13px] text-text-secondary'>
{isProviderManaged ? t('common.modelProvider.defaultConfig') : config.name}
</div>
{isProviderManaged && providerFormSchemaPredefined && (
@ -206,34 +216,14 @@ const ModelLoadBalancingConfigs = ({
{!isProviderManaged && (
<>
<div className='flex items-center gap-1 opacity-0 transition-opacity group-hover:opacity-100'>
{
config.credential_id && !credential?.not_allowed_to_use && !credential?.from_enterprise && (
<span
className='flex h-8 w-8 cursor-pointer items-center justify-center rounded-lg bg-components-button-secondary-bg text-text-tertiary transition-colors hover:bg-components-button-secondary-bg-hover'
onClick={() => {
handleOpenModal(
provider,
configurationMethod,
currentCustomConfigurationModelFixedFields,
configurationMethod === ConfigurationMethodEnum.customizableModel,
(config.credential_id && config.name) ? {
credential_id: config.credential_id,
credential_name: config.name,
} : undefined,
model,
)
}}
>
<RiEqualizer2Line className='h-4 w-4' />
</span>
)
}
<span
className='flex h-8 w-8 cursor-pointer items-center justify-center rounded-lg bg-components-button-secondary-bg text-text-tertiary transition-colors hover:bg-components-button-secondary-bg-hover'
onClick={() => updateConfigEntry(index, () => undefined)}
>
<RiDeleteBinLine className='h-4 w-4' />
</span>
<Tooltip popupContent={t('common.operation.remove')}>
<span
className='flex h-8 w-8 cursor-pointer items-center justify-center rounded-lg bg-components-button-secondary-bg text-text-tertiary transition-colors hover:bg-components-button-secondary-bg-hover'
onClick={() => updateConfigEntry(index, () => undefined)}
>
<RiIndeterminateCircleLine className='h-4 w-4' />
</span>
</Tooltip>
</div>
</>
)}
@ -261,7 +251,8 @@ const ModelLoadBalancingConfigs = ({
configurationMethod={configurationMethod}
modelCredential={modelCredential}
onSelectCredential={addConfigEntry}
onUpdate={onUpdate}
onUpdate={handleUpdate}
onRemove={handleRemove}
/>
</div>
)}

View File

@ -2,6 +2,7 @@ import { memo, useCallback, useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import type {
Credential,
CustomConfigurationModelFixedFields,
ModelItem,
ModelLoadBalancingConfig,
ModelLoadBalancingConfigEntry,
@ -24,10 +25,14 @@ import {
useGetModelCredential,
useUpdateModelLoadBalancingConfig,
} from '@/service/use-models'
import { useAuth } from '../model-auth/hooks/use-auth'
import Confirm from '@/app/components/base/confirm'
import { useRefreshModel } from '../hooks'
export type ModelLoadBalancingModalProps = {
provider: ModelProvider
configurateMethod: ConfigurationMethodEnum
currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields
model: ModelItem
credential?: Credential
open?: boolean
@ -39,6 +44,7 @@ export type ModelLoadBalancingModalProps = {
const ModelLoadBalancingModal = ({
provider,
configurateMethod,
currentCustomConfigurationModelFixedFields,
model,
credential,
open = false,
@ -47,7 +53,20 @@ const ModelLoadBalancingModal = ({
}: ModelLoadBalancingModalProps) => {
const { t } = useTranslation()
const { notify } = useToastContext()
const {
doingAction,
deleteModel,
openConfirmDelete,
closeConfirmDelete,
handleConfirmDelete,
} = useAuth(
provider,
configurateMethod,
currentCustomConfigurationModelFixedFields,
{
isModelCredential: true,
},
)
const [loading, setLoading] = useState(false)
const providerFormSchemaPredefined = configurateMethod === ConfigurationMethodEnum.predefinedModel
const configFrom = providerFormSchemaPredefined ? 'predefined-model' : 'custom-model'
@ -121,6 +140,7 @@ const ModelLoadBalancingModal = ({
}
}, [current_credential_id, current_credential_name])
const [customModelCredential, setCustomModelCredential] = useState<Credential | undefined>(initialCustomModelCredential)
const { handleRefreshModel } = useRefreshModel()
const handleSave = async () => {
try {
setLoading(true)
@ -139,6 +159,7 @@ const ModelLoadBalancingModal = ({
)
if (res.result === 'success') {
notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
handleRefreshModel(provider, configurateMethod, currentCustomConfigurationModelFixedFields)
onSave?.(provider.provider)
onClose?.()
}
@ -147,120 +168,208 @@ const ModelLoadBalancingModal = ({
setLoading(false)
}
}
const handleDeleteModel = useCallback(async () => {
await handleConfirmDelete()
onClose?.()
}, [handleConfirmDelete, onClose])
const handleUpdate = useCallback(async (payload?: any, formValues?: Record<string, any>) => {
const result = await refetch()
const available_credentials = result.data?.available_credentials || []
const credentialName = formValues?.__authorization_name__
const modelCredential = payload?.credential
if (!available_credentials.length) {
onClose?.()
return
}
if (!modelCredential) {
const currentCredential = available_credentials.find(c => c.credential_name === credentialName)
if (currentCredential) {
setDraftConfig((prev: any) => {
if (!prev)
return prev
return {
...prev,
configs: [...prev.configs, {
credential_id: currentCredential.credential_id,
enabled: true,
name: currentCredential.credential_name,
}],
}
})
}
}
else {
setDraftConfig((prev) => {
if (!prev)
return prev
const newConfigs = [...prev.configs]
const prevIndex = newConfigs.findIndex(item => item.credential_id === modelCredential.credential_id && item.name !== '__inherit__')
const newIndex = available_credentials.findIndex(c => c.credential_id === modelCredential.credential_id)
if (newIndex > -1 && prevIndex > -1)
newConfigs[prevIndex].name = available_credentials[newIndex].credential_name || ''
return {
...prev,
configs: newConfigs,
}
})
}
}, [refetch, credential])
const handleUpdateWhenSwitchCredential = useCallback(async () => {
const result = await refetch()
const available_credentials = result.data?.available_credentials || []
if (!available_credentials.length)
onClose?.()
}, [refetch, onClose])
return (
<Modal
isShow={Boolean(model) && open}
onClose={onClose}
className='w-[640px] max-w-none px-8 pt-8'
title={
<div className='pb-3 font-semibold'>
<div className='h-[30px]'>{
draftConfig?.enabled
? t('common.modelProvider.auth.configLoadBalancing')
: t('common.modelProvider.auth.configModel')
}</div>
{Boolean(model) && (
<div className='flex h-5 items-center'>
<ModelIcon
className='mr-2 shrink-0'
provider={provider}
modelName={model!.model}
/>
<ModelName
className='system-md-regular grow text-text-secondary'
modelItem={model!}
showModelType
showMode
showContextSize
/>
</div>
)}
</div>
}
>
{!draftConfig
? <Loading type='area' />
: (
<>
<div className='py-2'>
<div
className={classNames(
'min-h-16 rounded-xl border bg-components-panel-bg transition-colors',
draftConfig.enabled ? 'cursor-pointer border-components-panel-border' : 'cursor-default border-util-colors-blue-blue-600',
)}
onClick={draftConfig.enabled ? () => toggleModalBalancing(false) : undefined}
>
<div className='flex select-none items-center gap-2 px-[15px] py-3'>
<div className='flex h-8 w-8 shrink-0 grow-0 items-center justify-center rounded-lg border border-components-card-border bg-components-card-bg'>
{Boolean(model) && (
<ModelIcon className='shrink-0' provider={provider} modelName={model!.model} />
)}
</div>
<div className='grow'>
<div className='text-sm text-text-secondary'>{
providerFormSchemaPredefined
? t('common.modelProvider.auth.providerManaged')
: t('common.modelProvider.auth.specifyModelCredential')
}</div>
<div className='text-xs text-text-tertiary'>{
providerFormSchemaPredefined
? t('common.modelProvider.auth.providerManagedTip')
: t('common.modelProvider.auth.specifyModelCredentialTip')
}</div>
<>
<Modal
isShow={Boolean(model) && open}
onClose={onClose}
className='w-[640px] max-w-none px-8 pt-8'
title={
<div className='pb-3 font-semibold'>
<div className='h-[30px]'>{
draftConfig?.enabled
? t('common.modelProvider.auth.configLoadBalancing')
: t('common.modelProvider.auth.configModel')
}</div>
{Boolean(model) && (
<div className='flex h-5 items-center'>
<ModelIcon
className='mr-2 shrink-0'
provider={provider}
modelName={model!.model}
/>
<ModelName
className='system-md-regular grow text-text-secondary'
modelItem={model!}
showModelType
showMode
showContextSize
/>
</div>
)}
</div>
}
>
{!draftConfig
? <Loading type='area' />
: (
<>
<div className='py-2'>
<div
className={classNames(
'min-h-16 rounded-xl border bg-components-panel-bg transition-colors',
draftConfig.enabled ? 'cursor-pointer border-components-panel-border' : 'cursor-default border-util-colors-blue-blue-600',
)}
onClick={draftConfig.enabled ? () => toggleModalBalancing(false) : undefined}
>
<div className='flex select-none items-center gap-2 px-[15px] py-3'>
<div className='flex h-8 w-8 shrink-0 grow-0 items-center justify-center rounded-lg border border-components-card-border bg-components-card-bg'>
{Boolean(model) && (
<ModelIcon className='shrink-0' provider={provider} modelName={model!.model} />
)}
</div>
<div className='grow'>
<div className='text-sm text-text-secondary'>{
providerFormSchemaPredefined
? t('common.modelProvider.auth.providerManaged')
: t('common.modelProvider.auth.specifyModelCredential')
}</div>
<div className='text-xs text-text-tertiary'>{
providerFormSchemaPredefined
? t('common.modelProvider.auth.providerManagedTip')
: t('common.modelProvider.auth.specifyModelCredentialTip')
}</div>
</div>
{
!providerFormSchemaPredefined && (
<SwitchCredentialInLoadBalancing
provider={provider}
customModelCredential={customModelCredential ?? initialCustomModelCredential}
setCustomModelCredential={setCustomModelCredential}
model={model}
credentials={available_credentials}
onUpdate={handleUpdateWhenSwitchCredential}
onRemove={handleUpdateWhenSwitchCredential}
/>
)
}
</div>
</div>
{
modelCredential && (
<ModelLoadBalancingConfigs {...{
draftConfig,
setDraftConfig,
provider,
currentCustomConfigurationModelFixedFields: {
__model_name: model.model,
__model_type: model.model_type,
},
configurationMethod: model.fetch_from,
className: 'mt-2',
modelCredential,
onUpdate: handleUpdate,
onRemove: handleUpdateWhenSwitchCredential,
model: {
model: model.model,
model_type: model.model_type,
},
}} />
)
}
</div>
<div className='mt-6 flex items-center justify-between gap-2'>
<div>
{
!providerFormSchemaPredefined && (
<SwitchCredentialInLoadBalancing
provider={provider}
customModelCredential={initialCustomModelCredential ?? customModelCredential}
setCustomModelCredential={setCustomModelCredential}
model={model}
credentials={available_credentials}
/>
<Button
onClick={() => openConfirmDelete(undefined, { model: model.model, model_type: model.model_type })}
className='text-components-button-destructive-secondary-text'
>
{t('common.modelProvider.auth.removeModel')}
</Button>
)
}
</div>
<div className='space-x-2'>
<Button onClick={onClose}>{t('common.operation.cancel')}</Button>
<Button
variant='primary'
onClick={handleSave}
disabled={
loading
|| (draftConfig?.enabled && (draftConfig?.configs.filter(config => config.enabled).length ?? 0) < 2)
|| isLoading
}
>{t('common.operation.save')}</Button>
</div>
</div>
{
modelCredential && (
<ModelLoadBalancingConfigs {...{
draftConfig,
setDraftConfig,
provider,
currentCustomConfigurationModelFixedFields: {
__model_name: model.model,
__model_type: model.model_type,
},
configurationMethod: model.fetch_from,
className: 'mt-2',
modelCredential,
onUpdate: refetch,
model: {
model: model.model,
model_type: model.model_type,
},
}} />
)
}
</div>
<div className='mt-6 flex items-center justify-end gap-2'>
<Button onClick={onClose}>{t('common.operation.cancel')}</Button>
<Button
variant='primary'
onClick={handleSave}
disabled={
loading
|| (draftConfig?.enabled && (draftConfig?.configs.filter(config => config.enabled).length ?? 0) < 2)
|| isLoading
}
>{t('common.operation.save')}</Button>
</div>
</>
</>
)
}
</Modal >
{
deleteModel && (
<Confirm
isShow
title={t('common.modelProvider.confirmDelete')}
onCancel={closeConfirmDelete}
onConfirm={handleDeleteModel}
isDisabled={doingAction}
/>
)
}
</Modal >
</>
)
}

View File

@ -161,7 +161,7 @@ export const modelTypeFormat = (modelType: ModelTypeEnum) => {
export const genModelTypeFormSchema = (modelTypes: ModelTypeEnum[]) => {
return {
type: FormTypeEnum.radio,
type: FormTypeEnum.select,
label: {
zh_Hans: '模型类型',
en_US: 'Model Type',

View File

@ -9,7 +9,6 @@ import type {
Credential,
CustomConfigurationModelFixedFields,
CustomModel,
ModelLoadBalancingConfigEntry,
ModelProvider,
} from '@/app/components/header/account-setting/model-provider-page/declarations'
import {
@ -29,6 +28,7 @@ import { removeSpecificQueryParam } from '@/utils'
import { noop } from 'lodash-es'
import dynamic from 'next/dynamic'
import type { ExpireNoticeModalPayloadProps } from '@/app/education-apply/expire-notice-modal'
import type { ModelModalModeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
const AccountSetting = dynamic(() => import('@/app/components/header/account-setting'), {
ssr: false,
@ -71,8 +71,8 @@ const ExpireNoticeModal = dynamic(() => import('@/app/education-apply/expire-not
export type ModalState<T> = {
payload: T
onCancelCallback?: () => void
onSaveCallback?: (newPayload: T) => void
onRemoveCallback?: (newPayload: T) => void
onSaveCallback?: (newPayload?: T, formValues?: Record<string, any>) => void
onRemoveCallback?: (newPayload?: T, formValues?: Record<string, any>) => void
onEditCallback?: (newPayload: T) => void
onValidateBeforeSaveCallback?: (newPayload: T) => boolean
isEditMode?: boolean
@ -86,10 +86,7 @@ export type ModelModalType = {
isModelCredential?: boolean
credential?: Credential
model?: CustomModel
}
export type LoadBalancingEntryModalType = ModelModalType & {
entry?: ModelLoadBalancingConfigEntry
index?: number
mode?: ModelModalModeEnum
}
export type ModalContextState = {
@ -187,9 +184,15 @@ export const ModalContextProvider = ({
showModelModal.onCancelCallback()
}, [showModelModal])
const handleSaveModelModal = useCallback(() => {
const handleSaveModelModal = useCallback((formValues?: Record<string, any>) => {
if (showModelModal?.onSaveCallback)
showModelModal.onSaveCallback(showModelModal.payload)
showModelModal.onSaveCallback(showModelModal.payload, formValues)
setShowModelModal(null)
}, [showModelModal])
const handleRemoveModelModal = useCallback((formValues?: Record<string, any>) => {
if (showModelModal?.onRemoveCallback)
showModelModal.onRemoveCallback(showModelModal.payload, formValues)
setShowModelModal(null)
}, [showModelModal])
@ -329,8 +332,10 @@ export const ModalContextProvider = ({
isModelCredential={showModelModal.payload.isModelCredential}
credential={showModelModal.payload.credential}
model={showModelModal.payload.model}
mode={showModelModal.payload.mode}
onCancel={handleCancelModelModal}
onSave={handleSaveModelModal}
onRemove={handleRemoveModelModal}
/>
)
}

View File

@ -498,10 +498,13 @@ const translation = {
authRemoved: 'Auth removed',
apiKeys: 'API Keys',
addApiKey: 'Add API Key',
addModel: 'Add model',
addNewModel: 'Add new model',
addCredential: 'Add credential',
addModelCredential: 'Add model credential',
editModelCredential: 'Edit model credential',
modelCredentials: 'Model credentials',
modelCredential: 'Model credential',
configModel: 'Config model',
configLoadBalancing: 'Config Load Balancing',
authorizationError: 'Authorization error',
@ -514,6 +517,12 @@ const translation = {
desc: 'After configuring credentials, all members within the workspace can use this model when orchestrating applications.',
addModel: 'Add model',
},
manageCredentials: 'Manage Credentials',
customModelCredentials: 'Custom Model Credentials',
addNewModelCredential: 'Add new model credential',
removeModel: 'Remove Model',
selectModelCredential: 'Select a model credential',
customModelCredentialsDeleteTip: 'Credential is in use and cannot be deleted',
},
},
dataSource: {

View File

@ -492,10 +492,13 @@ const translation = {
authRemoved: '授权已移除',
apiKeys: 'API 密钥',
addApiKey: '添加 API 密钥',
addModel: '添加模型',
addNewModel: '添加新模型',
addCredential: '添加凭据',
addModelCredential: '添加模型凭据',
editModelCredential: '编辑模型凭据',
modelCredentials: '模型凭据',
modelCredential: '模型凭据',
configModel: '配置模型',
configLoadBalancing: '配置负载均衡',
authorizationError: '授权错误',
@ -508,6 +511,12 @@ const translation = {
desc: '配置凭据后,工作空间中的所有成员都可以在编排应用时使用此模型。',
addModel: '添加模型',
},
manageCredentials: '管理凭据',
customModelCredentials: '自定义模型凭据',
addNewModelCredential: '添加模型新凭据',
removeModel: '移除模型',
selectModelCredential: '选择模型凭据',
customModelCredentialsDeleteTip: '模型凭据正在使用中,无法删除',
},
},
dataSource: {

View File

@ -122,7 +122,7 @@ export const useDeleteModel = (provider: string) => {
mutationFn: (data: {
model: string
model_type: ModelTypeEnum
}) => del<{ result: string }>(`/workspaces/current/model-providers/${provider}/models/credentials`, {
}) => del<{ result: string }>(`/workspaces/current/model-providers/${provider}/models`, {
body: data,
}),
})