diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 34a27b954c..fdfceeb148 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -266,6 +266,7 @@ class PipelineGenerator(BaseAppGenerator): else: priority_rag_pipeline_run_task.delay( # type: ignore rag_pipeline_invoke_entities_file_id=upload_file.id, + tenant_id=dataset.tenant_id, ) # return batch, dataset, documents diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 41884661b2..1b5077df7b 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -345,19 +345,18 @@ class DatasourceProviderService: def is_tenant_oauth_params_enabled(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> bool: """ check if tenant oauth params is enabled - """ - with Session(db.engine).no_autoflush as session: - return ( - session.query(DatasourceOauthTenantParamConfig) - .filter_by( - tenant_id=tenant_id, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, - enabled=True, - ) - .count() - > 0 + """ + return ( + db.session.query(DatasourceOauthTenantParamConfig) + .filter_by( + tenant_id=tenant_id, + provider=datasource_provider_id.provider_name, + plugin_id=datasource_provider_id.plugin_id, + enabled=True, ) + .count() + > 0 + ) def get_tenant_oauth_client( self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False @@ -365,23 +364,22 @@ class DatasourceProviderService: """ get tenant oauth client """ - with Session(db.engine).no_autoflush as session: - tenant_oauth_client_params = ( - session.query(DatasourceOauthTenantParamConfig) - .filter_by( - tenant_id=tenant_id, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, - ) - .first() + tenant_oauth_client_params = ( + db.session.query(DatasourceOauthTenantParamConfig) + .filter_by( + tenant_id=tenant_id, + provider=datasource_provider_id.provider_name, + plugin_id=datasource_provider_id.plugin_id, ) - if tenant_oauth_client_params: - encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id) - if mask: - return encrypter.mask_tool_credentials(encrypter.decrypt(tenant_oauth_client_params.client_params)) - else: - return encrypter.decrypt(tenant_oauth_client_params.client_params) - return None + .first() + ) + if tenant_oauth_client_params: + encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id) + if mask: + return encrypter.mask_tool_credentials(encrypter.decrypt(tenant_oauth_client_params.client_params)) + else: + return encrypter.decrypt(tenant_oauth_client_params.client_params) + return None def get_oauth_encrypter( self, tenant_id: str, datasource_provider_id: DatasourceProviderID