fix: move remote credential validation outside DB session to prevent … (#35350)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
zyssyz123 2026-04-17 15:42:29 +08:00 committed by GitHub
parent eaddd4a132
commit a74e12809b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 156 additions and 177 deletions

View File

@ -318,34 +318,28 @@ class ProviderConfiguration(BaseModel):
else [],
)
def validate_provider_credentials(
self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
):
def validate_provider_credentials(self, credentials: dict[str, Any], credential_id: str = ""):
"""
Validate custom credentials.
:param credentials: provider credentials
:param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
:param session: optional database session
:return:
"""
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema
else []
)
def _validate(s: Session):
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema
else []
)
if credential_id:
if credential_id:
with Session(db.engine) as session:
try:
stmt = select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.id == credential_id,
)
credential_record = s.execute(stmt).scalar_one_or_none()
# fix origin data
credential_record = session.execute(stmt).scalar_one_or_none()
if credential_record and credential_record.encrypted_config:
if not credential_record.encrypted_config.startswith("{"):
original_credentials = {"openai_api_key": credential_record.encrypted_config}
@ -356,31 +350,23 @@ class ProviderConfiguration(BaseModel):
except JSONDecodeError:
original_credentials = {}
# encrypt credentials
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(
tenant_id=self.tenant_id, token=original_credentials[key]
)
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.provider_credentials_validate(
provider=self.provider.provider, credentials=credentials
)
for key, value in validated_credentials.items():
for key, value in credentials.items():
if key in provider_credential_secret_variables:
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(
tenant_id=self.tenant_id, token=original_credentials[key]
)
return validated_credentials
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.provider_credentials_validate(
provider=self.provider.provider, credentials=credentials
)
if session:
return _validate(session)
else:
with Session(db.engine) as new_session:
return _validate(new_session)
for key, value in validated_credentials.items():
if key in provider_credential_secret_variables:
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return validated_credentials
def _generate_provider_credential_name(self, session) -> str:
"""
@ -457,14 +443,16 @@ class ProviderConfiguration(BaseModel):
:param credential_name: credential name
:return:
"""
with Session(db.engine) as session:
with Session(db.engine) as pre_session:
if credential_name:
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
if self._check_provider_credential_name_exists(credential_name=credential_name, session=pre_session):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
else:
credential_name = self._generate_provider_credential_name(session)
credential_name = self._generate_provider_credential_name(pre_session)
credentials = self.validate_provider_credentials(credentials=credentials, session=session)
credentials = self.validate_provider_credentials(credentials=credentials)
with Session(db.engine) as session:
provider_record = self._get_provider_record(session)
try:
new_record = ProviderCredential(
@ -477,7 +465,6 @@ class ProviderConfiguration(BaseModel):
session.flush()
if not provider_record:
# If provider record does not exist, create it
provider_record = Provider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
@ -530,15 +517,15 @@ class ProviderConfiguration(BaseModel):
:param credential_name: credential name
:return:
"""
with Session(db.engine) as session:
with Session(db.engine) as pre_session:
if credential_name and self._check_provider_credential_name_exists(
credential_name=credential_name, session=session, exclude_id=credential_id
credential_name=credential_name, session=pre_session, exclude_id=credential_id
):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
credentials = self.validate_provider_credentials(
credentials=credentials, credential_id=credential_id, session=session
)
credentials = self.validate_provider_credentials(credentials=credentials, credential_id=credential_id)
with Session(db.engine) as session:
provider_record = self._get_provider_record(session)
stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id,
@ -546,12 +533,10 @@ class ProviderConfiguration(BaseModel):
ProviderCredential.provider_name.in_(self._get_provider_names()),
)
# Get the credential record to update
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
raise ValueError("Credential record not found.")
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.updated_at = naive_utc_now()
if credential_name:
@ -879,7 +864,6 @@ class ProviderConfiguration(BaseModel):
model: str,
credentials: dict[str, Any],
credential_id: str = "",
session: Session | None = None,
):
"""
Validate custom model credentials.
@ -890,16 +874,14 @@ class ProviderConfiguration(BaseModel):
:param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
:return:
"""
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema
else []
)
def _validate(s: Session):
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema
else []
)
if credential_id:
if credential_id:
with Session(db.engine) as session:
try:
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
@ -908,7 +890,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
)
credential_record = s.execute(stmt).scalar_one_or_none()
credential_record = session.execute(stmt).scalar_one_or_none()
original_credentials = (
json.loads(credential_record.encrypted_config)
if credential_record and credential_record.encrypted_config
@ -917,31 +899,23 @@ class ProviderConfiguration(BaseModel):
except JSONDecodeError:
original_credentials = {}
# decrypt credentials
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(
tenant_id=self.tenant_id, token=original_credentials[key]
)
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
for key, value in validated_credentials.items():
for key, value in credentials.items():
if key in provider_credential_secret_variables:
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(
tenant_id=self.tenant_id, token=original_credentials[key]
)
return validated_credentials
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
if session:
return _validate(session)
else:
with Session(db.engine) as new_session:
return _validate(new_session)
for key, value in validated_credentials.items():
if key in provider_credential_secret_variables:
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return validated_credentials
def create_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
@ -954,20 +928,22 @@ class ProviderConfiguration(BaseModel):
:param credentials: model credentials dict
:return:
"""
with Session(db.engine) as session:
with Session(db.engine) as pre_session:
if credential_name:
if self._check_custom_model_credential_name_exists(
model=model, model_type=model_type, credential_name=credential_name, session=session
model=model, model_type=model_type, credential_name=credential_name, session=pre_session
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
else:
credential_name = self._generate_custom_model_credential_name(
model=model, model_type=model_type, session=session
model=model, model_type=model_type, session=pre_session
)
# validate custom model config
credentials = self.validate_custom_model_credentials(
model_type=model_type, model=model, credentials=credentials, session=session
)
credentials = self.validate_custom_model_credentials(
model_type=model_type, model=model, credentials=credentials
)
with Session(db.engine) as session:
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
try:
@ -982,7 +958,6 @@ class ProviderConfiguration(BaseModel):
session.add(credential)
session.flush()
# save provider model
if not provider_model_record:
provider_model_record = ProviderModel(
tenant_id=self.tenant_id,
@ -1024,23 +999,24 @@ class ProviderConfiguration(BaseModel):
:param credential_id: credential id
:return:
"""
with Session(db.engine) as session:
with Session(db.engine) as pre_session:
if credential_name and self._check_custom_model_credential_name_exists(
model=model,
model_type=model_type,
credential_name=credential_name,
session=session,
session=pre_session,
exclude_id=credential_id,
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
# validate custom model config
credentials = self.validate_custom_model_credentials(
model_type=model_type,
model=model,
credentials=credentials,
credential_id=credential_id,
session=session,
)
credentials = self.validate_custom_model_credentials(
model_type=model_type,
model=model,
credentials=credentials,
credential_id=credential_id,
)
with Session(db.engine) as session:
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
stmt = select(ProviderModelCredential).where(
@ -1055,7 +1031,6 @@ class ProviderConfiguration(BaseModel):
raise ValueError("Credential record not found.")
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.updated_at = naive_utc_now()
if credential_name:

View File

@ -345,22 +345,26 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None:
)
]
)
session = Mock()
session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="encrypted-old-key")
mock_session = Mock()
mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
encrypted_config="encrypted-old-key"
)
mock_factory = Mock()
mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "restored-key", "region": "us"}
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"):
with patch(
"core.entities.provider_configuration.encrypter.encrypt_token",
side_effect=lambda tenant_id, value: f"enc::{value}",
):
validated = configuration.validate_provider_credentials(
credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"},
credential_id="credential-1",
session=session,
)
with _patched_session(mock_session):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
):
with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"):
with patch(
"core.entities.provider_configuration.encrypter.encrypt_token",
side_effect=lambda tenant_id, value: f"enc::{value}",
):
validated = configuration.validate_provider_credentials(
credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"},
credential_id="credential-1",
)
assert validated["openai_api_key"] == "enc::restored-key"
assert validated["region"] == "us"
@ -370,23 +374,15 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None:
)
def test_validate_provider_credentials_opens_session_when_not_passed() -> None:
def test_validate_provider_credentials_without_credential_id() -> None:
configuration = _build_provider_configuration()
mock_session = Mock()
mock_factory = Mock()
mock_factory.provider_credentials_validate.return_value = {"region": "us"}
with patch("core.entities.provider_configuration.Session") as mock_session_cls:
with patch("core.entities.provider_configuration.db") as mock_db:
mock_db.engine = Mock()
mock_session_cls.return_value.__enter__.return_value = mock_session
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
):
validated = configuration.validate_provider_credentials(credentials={"region": "us"})
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
validated = configuration.validate_provider_credentials(credentials={"region": "us"})
assert validated == {"region": "us"}
mock_session_cls.assert_called_once()
def test_switch_preferred_provider_type_returns_early_when_no_change_or_unsupported() -> None:
@ -717,18 +713,22 @@ def test_check_provider_credential_name_exists_and_model_setting_lookup() -> Non
def test_validate_provider_credentials_handles_invalid_original_json() -> None:
configuration = _build_provider_configuration()
configuration.provider.provider_credential_schema = _build_secret_provider_schema()
session = Mock()
session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json")
mock_session = Mock()
mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
encrypted_config="{invalid-json"
)
mock_factory = Mock()
mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "new-key"}
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"):
validated = configuration.validate_provider_credentials(
credentials={"openai_api_key": HIDDEN_VALUE},
credential_id="cred-1",
session=session,
)
with _patched_session(mock_session):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"):
validated = configuration.validate_provider_credentials(
credentials={"openai_api_key": HIDDEN_VALUE},
credential_id="cred-1",
)
assert validated == {"openai_api_key": "enc-key"}
@ -1060,37 +1060,35 @@ def test_get_custom_model_credential_uses_specific_id_or_configuration_fallback(
def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless_path() -> None:
configuration = _build_provider_configuration()
configuration.provider.model_credential_schema = _build_secret_model_schema()
session = Mock()
session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
mock_session = Mock()
mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
encrypted_config='{"openai_api_key":"enc"}'
)
mock_factory = Mock()
mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"}
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
validated = configuration.validate_custom_model_credentials(
model_type=ModelType.LLM,
model="gpt-4o",
credentials={"openai_api_key": HIDDEN_VALUE},
credential_id="cred-1",
session=session,
)
assert validated == {"openai_api_key": "enc-new"}
session = Mock()
mock_factory = Mock()
mock_factory.model_credentials_validate.return_value = {"region": "us"}
with _patched_session(session):
with _patched_session(mock_session):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
):
validated = configuration.validate_custom_model_credentials(
model_type=ModelType.LLM,
model="gpt-4o",
credentials={"region": "us"},
)
with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
validated = configuration.validate_custom_model_credentials(
model_type=ModelType.LLM,
model="gpt-4o",
credentials={"openai_api_key": HIDDEN_VALUE},
credential_id="cred-1",
)
assert validated == {"openai_api_key": "enc-new"}
mock_factory2 = Mock()
mock_factory2.model_credentials_validate.return_value = {"region": "us"}
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory2):
validated = configuration.validate_custom_model_credentials(
model_type=ModelType.LLM,
model="gpt-4o",
credentials={"region": "us"},
)
assert validated == {"region": "us"}
@ -1570,18 +1568,20 @@ def test_get_specific_provider_credential_logs_when_decrypt_fails() -> None:
def test_validate_provider_credentials_uses_empty_original_when_record_missing() -> None:
configuration = _build_provider_configuration()
configuration.provider.provider_credential_schema = _build_secret_provider_schema()
session = Mock()
session.execute.return_value.scalar_one_or_none.return_value = None
mock_session = Mock()
mock_session.execute.return_value.scalar_one_or_none.return_value = None
mock_factory = Mock()
mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "raw"}
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
validated = configuration.validate_provider_credentials(
credentials={"openai_api_key": HIDDEN_VALUE},
credential_id="cred-1",
session=session,
)
with _patched_session(mock_session):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
validated = configuration.validate_provider_credentials(
credentials={"openai_api_key": HIDDEN_VALUE},
credential_id="cred-1",
)
assert validated == {"openai_api_key": "enc-new"}
@ -1692,20 +1692,24 @@ def test_get_specific_custom_model_credential_logs_when_decrypt_fails() -> None:
def test_validate_custom_model_credentials_handles_invalid_original_json() -> None:
configuration = _build_provider_configuration()
configuration.provider.model_credential_schema = _build_secret_model_schema()
session = Mock()
session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json")
mock_session = Mock()
mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
encrypted_config="{invalid-json"
)
mock_factory = Mock()
mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"}
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
validated = configuration.validate_custom_model_credentials(
model_type=ModelType.LLM,
model="gpt-4o",
credentials={"openai_api_key": HIDDEN_VALUE},
credential_id="cred-1",
session=session,
)
with _patched_session(mock_session):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
validated = configuration.validate_custom_model_credentials(
model_type=ModelType.LLM,
model="gpt-4o",
credentials={"openai_api_key": HIDDEN_VALUE},
credential_id="cred-1",
)
assert validated == {"openai_api_key": "enc-new"}