mirror of
https://github.com/langgenius/dify.git
synced 2026-06-19 08:31:07 +08:00
fix: move remote credential validation outside DB session to prevent … (#35350)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
eaddd4a132
commit
a74e12809b
@ -318,34 +318,28 @@ class ProviderConfiguration(BaseModel):
|
||||
else [],
|
||||
)
|
||||
|
||||
def validate_provider_credentials(
|
||||
self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
|
||||
):
|
||||
def validate_provider_credentials(self, credentials: dict[str, Any], credential_id: str = ""):
|
||||
"""
|
||||
Validate custom credentials.
|
||||
:param credentials: provider credentials
|
||||
:param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
|
||||
:param session: optional database session
|
||||
:return:
|
||||
"""
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
self.provider.provider_credential_schema.credential_form_schemas
|
||||
if self.provider.provider_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
def _validate(s: Session):
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
self.provider.provider_credential_schema.credential_form_schemas
|
||||
if self.provider.provider_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
if credential_id:
|
||||
if credential_id:
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderCredential.id == credential_id,
|
||||
)
|
||||
credential_record = s.execute(stmt).scalar_one_or_none()
|
||||
# fix origin data
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if credential_record and credential_record.encrypted_config:
|
||||
if not credential_record.encrypted_config.startswith("{"):
|
||||
original_credentials = {"openai_api_key": credential_record.encrypted_config}
|
||||
@ -356,31 +350,23 @@ class ProviderConfiguration(BaseModel):
|
||||
except JSONDecodeError:
|
||||
original_credentials = {}
|
||||
|
||||
# encrypt credentials
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
# if send [__HIDDEN__] in secret input, it will be same as original value
|
||||
if value == HIDDEN_VALUE and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(
|
||||
tenant_id=self.tenant_id, token=original_credentials[key]
|
||||
)
|
||||
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
validated_credentials = model_provider_factory.provider_credentials_validate(
|
||||
provider=self.provider.provider, credentials=credentials
|
||||
)
|
||||
|
||||
for key, value in validated_credentials.items():
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
if value == HIDDEN_VALUE and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(
|
||||
tenant_id=self.tenant_id, token=original_credentials[key]
|
||||
)
|
||||
|
||||
return validated_credentials
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
validated_credentials = model_provider_factory.provider_credentials_validate(
|
||||
provider=self.provider.provider, credentials=credentials
|
||||
)
|
||||
|
||||
if session:
|
||||
return _validate(session)
|
||||
else:
|
||||
with Session(db.engine) as new_session:
|
||||
return _validate(new_session)
|
||||
for key, value in validated_credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
|
||||
return validated_credentials
|
||||
|
||||
def _generate_provider_credential_name(self, session) -> str:
|
||||
"""
|
||||
@ -457,14 +443,16 @@ class ProviderConfiguration(BaseModel):
|
||||
:param credential_name: credential name
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as pre_session:
|
||||
if credential_name:
|
||||
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
|
||||
if self._check_provider_credential_name_exists(credential_name=credential_name, session=pre_session):
|
||||
raise ValueError(f"Credential with name '{credential_name}' already exists.")
|
||||
else:
|
||||
credential_name = self._generate_provider_credential_name(session)
|
||||
credential_name = self._generate_provider_credential_name(pre_session)
|
||||
|
||||
credentials = self.validate_provider_credentials(credentials=credentials, session=session)
|
||||
credentials = self.validate_provider_credentials(credentials=credentials)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
provider_record = self._get_provider_record(session)
|
||||
try:
|
||||
new_record = ProviderCredential(
|
||||
@ -477,7 +465,6 @@ class ProviderConfiguration(BaseModel):
|
||||
session.flush()
|
||||
|
||||
if not provider_record:
|
||||
# If provider record does not exist, create it
|
||||
provider_record = Provider(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
@ -530,15 +517,15 @@ class ProviderConfiguration(BaseModel):
|
||||
:param credential_name: credential name
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as pre_session:
|
||||
if credential_name and self._check_provider_credential_name_exists(
|
||||
credential_name=credential_name, session=session, exclude_id=credential_id
|
||||
credential_name=credential_name, session=pre_session, exclude_id=credential_id
|
||||
):
|
||||
raise ValueError(f"Credential with name '{credential_name}' already exists.")
|
||||
|
||||
credentials = self.validate_provider_credentials(
|
||||
credentials=credentials, credential_id=credential_id, session=session
|
||||
)
|
||||
credentials = self.validate_provider_credentials(credentials=credentials, credential_id=credential_id)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
provider_record = self._get_provider_record(session)
|
||||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.id == credential_id,
|
||||
@ -546,12 +533,10 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
|
||||
# Get the credential record to update
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if not credential_record:
|
||||
raise ValueError("Credential record not found.")
|
||||
try:
|
||||
# Update credential
|
||||
credential_record.encrypted_config = json.dumps(credentials)
|
||||
credential_record.updated_at = naive_utc_now()
|
||||
if credential_name:
|
||||
@ -879,7 +864,6 @@ class ProviderConfiguration(BaseModel):
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
credential_id: str = "",
|
||||
session: Session | None = None,
|
||||
):
|
||||
"""
|
||||
Validate custom model credentials.
|
||||
@ -890,16 +874,14 @@ class ProviderConfiguration(BaseModel):
|
||||
:param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
|
||||
:return:
|
||||
"""
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
self.provider.model_credential_schema.credential_form_schemas
|
||||
if self.provider.model_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
def _validate(s: Session):
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
self.provider.model_credential_schema.credential_form_schemas
|
||||
if self.provider.model_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
if credential_id:
|
||||
if credential_id:
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
@ -908,7 +890,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
credential_record = s.execute(stmt).scalar_one_or_none()
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
original_credentials = (
|
||||
json.loads(credential_record.encrypted_config)
|
||||
if credential_record and credential_record.encrypted_config
|
||||
@ -917,31 +899,23 @@ class ProviderConfiguration(BaseModel):
|
||||
except JSONDecodeError:
|
||||
original_credentials = {}
|
||||
|
||||
# decrypt credentials
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
# if send [__HIDDEN__] in secret input, it will be same as original value
|
||||
if value == HIDDEN_VALUE and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(
|
||||
tenant_id=self.tenant_id, token=original_credentials[key]
|
||||
)
|
||||
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
validated_credentials = model_provider_factory.model_credentials_validate(
|
||||
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
|
||||
)
|
||||
|
||||
for key, value in validated_credentials.items():
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
if value == HIDDEN_VALUE and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(
|
||||
tenant_id=self.tenant_id, token=original_credentials[key]
|
||||
)
|
||||
|
||||
return validated_credentials
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
validated_credentials = model_provider_factory.model_credentials_validate(
|
||||
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
|
||||
)
|
||||
|
||||
if session:
|
||||
return _validate(session)
|
||||
else:
|
||||
with Session(db.engine) as new_session:
|
||||
return _validate(new_session)
|
||||
for key, value in validated_credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
|
||||
return validated_credentials
|
||||
|
||||
def create_custom_model_credential(
|
||||
self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
|
||||
@ -954,20 +928,22 @@ class ProviderConfiguration(BaseModel):
|
||||
:param credentials: model credentials dict
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as pre_session:
|
||||
if credential_name:
|
||||
if self._check_custom_model_credential_name_exists(
|
||||
model=model, model_type=model_type, credential_name=credential_name, session=session
|
||||
model=model, model_type=model_type, credential_name=credential_name, session=pre_session
|
||||
):
|
||||
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
|
||||
else:
|
||||
credential_name = self._generate_custom_model_credential_name(
|
||||
model=model, model_type=model_type, session=session
|
||||
model=model, model_type=model_type, session=pre_session
|
||||
)
|
||||
# validate custom model config
|
||||
credentials = self.validate_custom_model_credentials(
|
||||
model_type=model_type, model=model, credentials=credentials, session=session
|
||||
)
|
||||
|
||||
credentials = self.validate_custom_model_credentials(
|
||||
model_type=model_type, model=model, credentials=credentials
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
|
||||
|
||||
try:
|
||||
@ -982,7 +958,6 @@ class ProviderConfiguration(BaseModel):
|
||||
session.add(credential)
|
||||
session.flush()
|
||||
|
||||
# save provider model
|
||||
if not provider_model_record:
|
||||
provider_model_record = ProviderModel(
|
||||
tenant_id=self.tenant_id,
|
||||
@ -1024,23 +999,24 @@ class ProviderConfiguration(BaseModel):
|
||||
:param credential_id: credential id
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as pre_session:
|
||||
if credential_name and self._check_custom_model_credential_name_exists(
|
||||
model=model,
|
||||
model_type=model_type,
|
||||
credential_name=credential_name,
|
||||
session=session,
|
||||
session=pre_session,
|
||||
exclude_id=credential_id,
|
||||
):
|
||||
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
|
||||
# validate custom model config
|
||||
credentials = self.validate_custom_model_credentials(
|
||||
model_type=model_type,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
credential_id=credential_id,
|
||||
session=session,
|
||||
)
|
||||
|
||||
credentials = self.validate_custom_model_credentials(
|
||||
model_type=model_type,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
|
||||
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
@ -1055,7 +1031,6 @@ class ProviderConfiguration(BaseModel):
|
||||
raise ValueError("Credential record not found.")
|
||||
|
||||
try:
|
||||
# Update credential
|
||||
credential_record.encrypted_config = json.dumps(credentials)
|
||||
credential_record.updated_at = naive_utc_now()
|
||||
if credential_name:
|
||||
|
||||
@ -345,22 +345,26 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None:
|
||||
)
|
||||
]
|
||||
)
|
||||
session = Mock()
|
||||
session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="encrypted-old-key")
|
||||
mock_session = Mock()
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
|
||||
encrypted_config="encrypted-old-key"
|
||||
)
|
||||
mock_factory = Mock()
|
||||
mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "restored-key", "region": "us"}
|
||||
|
||||
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
|
||||
with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.encrypter.encrypt_token",
|
||||
side_effect=lambda tenant_id, value: f"enc::{value}",
|
||||
):
|
||||
validated = configuration.validate_provider_credentials(
|
||||
credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"},
|
||||
credential_id="credential-1",
|
||||
session=session,
|
||||
)
|
||||
with _patched_session(mock_session):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
|
||||
):
|
||||
with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.encrypter.encrypt_token",
|
||||
side_effect=lambda tenant_id, value: f"enc::{value}",
|
||||
):
|
||||
validated = configuration.validate_provider_credentials(
|
||||
credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"},
|
||||
credential_id="credential-1",
|
||||
)
|
||||
|
||||
assert validated["openai_api_key"] == "enc::restored-key"
|
||||
assert validated["region"] == "us"
|
||||
@ -370,23 +374,15 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None:
|
||||
)
|
||||
|
||||
|
||||
def test_validate_provider_credentials_opens_session_when_not_passed() -> None:
|
||||
def test_validate_provider_credentials_without_credential_id() -> None:
|
||||
configuration = _build_provider_configuration()
|
||||
mock_session = Mock()
|
||||
mock_factory = Mock()
|
||||
mock_factory.provider_credentials_validate.return_value = {"region": "us"}
|
||||
|
||||
with patch("core.entities.provider_configuration.Session") as mock_session_cls:
|
||||
with patch("core.entities.provider_configuration.db") as mock_db:
|
||||
mock_db.engine = Mock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
|
||||
):
|
||||
validated = configuration.validate_provider_credentials(credentials={"region": "us"})
|
||||
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
|
||||
validated = configuration.validate_provider_credentials(credentials={"region": "us"})
|
||||
|
||||
assert validated == {"region": "us"}
|
||||
mock_session_cls.assert_called_once()
|
||||
|
||||
|
||||
def test_switch_preferred_provider_type_returns_early_when_no_change_or_unsupported() -> None:
|
||||
@ -717,18 +713,22 @@ def test_check_provider_credential_name_exists_and_model_setting_lookup() -> Non
|
||||
def test_validate_provider_credentials_handles_invalid_original_json() -> None:
|
||||
configuration = _build_provider_configuration()
|
||||
configuration.provider.provider_credential_schema = _build_secret_provider_schema()
|
||||
session = Mock()
|
||||
session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json")
|
||||
mock_session = Mock()
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
|
||||
encrypted_config="{invalid-json"
|
||||
)
|
||||
mock_factory = Mock()
|
||||
mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "new-key"}
|
||||
|
||||
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
|
||||
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"):
|
||||
validated = configuration.validate_provider_credentials(
|
||||
credentials={"openai_api_key": HIDDEN_VALUE},
|
||||
credential_id="cred-1",
|
||||
session=session,
|
||||
)
|
||||
with _patched_session(mock_session):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
|
||||
):
|
||||
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"):
|
||||
validated = configuration.validate_provider_credentials(
|
||||
credentials={"openai_api_key": HIDDEN_VALUE},
|
||||
credential_id="cred-1",
|
||||
)
|
||||
|
||||
assert validated == {"openai_api_key": "enc-key"}
|
||||
|
||||
@ -1060,37 +1060,35 @@ def test_get_custom_model_credential_uses_specific_id_or_configuration_fallback(
|
||||
def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless_path() -> None:
|
||||
configuration = _build_provider_configuration()
|
||||
configuration.provider.model_credential_schema = _build_secret_model_schema()
|
||||
session = Mock()
|
||||
session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
|
||||
mock_session = Mock()
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
|
||||
encrypted_config='{"openai_api_key":"enc"}'
|
||||
)
|
||||
mock_factory = Mock()
|
||||
mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"}
|
||||
|
||||
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
|
||||
with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"):
|
||||
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
|
||||
validated = configuration.validate_custom_model_credentials(
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o",
|
||||
credentials={"openai_api_key": HIDDEN_VALUE},
|
||||
credential_id="cred-1",
|
||||
session=session,
|
||||
)
|
||||
assert validated == {"openai_api_key": "enc-new"}
|
||||
|
||||
session = Mock()
|
||||
mock_factory = Mock()
|
||||
mock_factory.model_credentials_validate.return_value = {"region": "us"}
|
||||
with _patched_session(session):
|
||||
with _patched_session(mock_session):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
|
||||
):
|
||||
validated = configuration.validate_custom_model_credentials(
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o",
|
||||
credentials={"region": "us"},
|
||||
)
|
||||
with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"):
|
||||
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
|
||||
validated = configuration.validate_custom_model_credentials(
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o",
|
||||
credentials={"openai_api_key": HIDDEN_VALUE},
|
||||
credential_id="cred-1",
|
||||
)
|
||||
assert validated == {"openai_api_key": "enc-new"}
|
||||
|
||||
mock_factory2 = Mock()
|
||||
mock_factory2.model_credentials_validate.return_value = {"region": "us"}
|
||||
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory2):
|
||||
validated = configuration.validate_custom_model_credentials(
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o",
|
||||
credentials={"region": "us"},
|
||||
)
|
||||
assert validated == {"region": "us"}
|
||||
|
||||
|
||||
@ -1570,18 +1568,20 @@ def test_get_specific_provider_credential_logs_when_decrypt_fails() -> None:
|
||||
def test_validate_provider_credentials_uses_empty_original_when_record_missing() -> None:
|
||||
configuration = _build_provider_configuration()
|
||||
configuration.provider.provider_credential_schema = _build_secret_provider_schema()
|
||||
session = Mock()
|
||||
session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
mock_session = Mock()
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
mock_factory = Mock()
|
||||
mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "raw"}
|
||||
|
||||
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
|
||||
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
|
||||
validated = configuration.validate_provider_credentials(
|
||||
credentials={"openai_api_key": HIDDEN_VALUE},
|
||||
credential_id="cred-1",
|
||||
session=session,
|
||||
)
|
||||
with _patched_session(mock_session):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
|
||||
):
|
||||
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
|
||||
validated = configuration.validate_provider_credentials(
|
||||
credentials={"openai_api_key": HIDDEN_VALUE},
|
||||
credential_id="cred-1",
|
||||
)
|
||||
|
||||
assert validated == {"openai_api_key": "enc-new"}
|
||||
|
||||
@ -1692,20 +1692,24 @@ def test_get_specific_custom_model_credential_logs_when_decrypt_fails() -> None:
|
||||
def test_validate_custom_model_credentials_handles_invalid_original_json() -> None:
|
||||
configuration = _build_provider_configuration()
|
||||
configuration.provider.model_credential_schema = _build_secret_model_schema()
|
||||
session = Mock()
|
||||
session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json")
|
||||
mock_session = Mock()
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
|
||||
encrypted_config="{invalid-json"
|
||||
)
|
||||
mock_factory = Mock()
|
||||
mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"}
|
||||
|
||||
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
|
||||
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
|
||||
validated = configuration.validate_custom_model_credentials(
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o",
|
||||
credentials={"openai_api_key": HIDDEN_VALUE},
|
||||
credential_id="cred-1",
|
||||
session=session,
|
||||
)
|
||||
with _patched_session(mock_session):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
|
||||
):
|
||||
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
|
||||
validated = configuration.validate_custom_model_credentials(
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o",
|
||||
credentials={"openai_api_key": HIDDEN_VALUE},
|
||||
credential_id="cred-1",
|
||||
)
|
||||
|
||||
assert validated == {"openai_api_key": "enc-new"}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user