From 16dff9d2c35f8ffb2bbfb78b4c46bc0b8ab00b4b Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Fri, 14 Jun 2024 00:50:15 -0700 Subject: [PATCH] Fix init encryption master key (#2554) * fix init master key Signed-off-by: Yaliang Wu (cherry picked from commit 487f33a2e35e642429e9a3ea1eb0d715d542ea9f) --- .../ml/engine/encryptor/EncryptorImpl.java | 82 ++++- .../engine/encryptor/EncryptorImplTest.java | 348 +++++++++++++++++- .../opensearch/ml/cluster/MLSyncUpCron.java | 2 + .../ml/plugin/MachineLearningPlugin.java | 4 +- 4 files changed, 412 insertions(+), 24 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java index 617c6871e5..1c02a7f915 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java @@ -8,29 +8,36 @@ import static java.util.concurrent.TimeUnit.SECONDS; import static org.opensearch.ml.common.CommonValue.MASTER_KEY; import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; +import static org.opensearch.ml.common.MLConfig.CREATE_TIME_FIELD; import java.nio.charset.StandardCharsets; import java.security.SecureRandom; +import java.time.Instant; import java.util.Base64; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicReference; import javax.crypto.spec.SecretKeySpec; +import org.apache.commons.lang3.exception.ExceptionUtils; import org.opensearch.ResourceNotFoundException; -import org.opensearch.action.LatchedActionListener; +import org.opensearch.action.DocWriteRequest; import org.opensearch.action.get.GetRequest; -import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.support.WriteRequest; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.ml.common.exception.MLException; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import com.amazonaws.encryptionsdk.AwsCrypto; import com.amazonaws.encryptionsdk.CommitmentPolicy; import com.amazonaws.encryptionsdk.CryptoResult; import com.amazonaws.encryptionsdk.jce.JceMasterKey; +import com.google.common.collect.ImmutableMap; import lombok.extern.log4j.Log4j2; @@ -42,11 +49,13 @@ public class EncryptorImpl implements Encryptor { private ClusterService clusterService; private Client client; private volatile String masterKey; + private MLIndicesHandler mlIndicesHandler; - public EncryptorImpl(ClusterService clusterService, Client client) { + public EncryptorImpl(ClusterService clusterService, Client client, MLIndicesHandler mlIndicesHandler) { this.masterKey = null; this.clusterService = clusterService; this.client = client; + this.mlIndicesHandler = mlIndicesHandler; } public EncryptorImpl(String masterKey) { @@ -104,28 +113,68 @@ private void initMasterKey() { AtomicReference exceptionRef = new AtomicReference<>(); CountDownLatch latch = new CountDownLatch(1); - if (clusterService.state().metadata().hasIndex(ML_CONFIG_INDEX)) { + mlIndicesHandler.initMLConfigIndex(ActionListener.wrap(r -> { + GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY); - client.get(getRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.wrap(r -> { - if (r.isExists()) { - String masterKey = (String) r.getSourceAsMap().get(MASTER_KEY); - this.masterKey = masterKey; + client.get(getRequest, ActionListener.wrap(getResponse -> { + if (getResponse == null || !getResponse.isExists()) { + IndexRequest indexRequest = new IndexRequest(ML_CONFIG_INDEX).id(MASTER_KEY); + final String generatedMasterKey = generateMasterKey(); + indexRequest + .source(ImmutableMap.of(MASTER_KEY, generatedMasterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli())); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + indexRequest.opType(DocWriteRequest.OpType.CREATE); + client.index(indexRequest, ActionListener.wrap(indexResponse -> { + this.masterKey = generatedMasterKey; + log.info("ML encryption master key initialized successfully"); + latch.countDown(); + }, e -> { + + if (ExceptionUtils.getRootCause(e) instanceof VersionConflictEngineException) { + GetRequest getMasterKeyRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + client.get(getMasterKeyRequest, ActionListener.wrap(getMasterKeyResponse -> { + if (getMasterKeyResponse != null && getMasterKeyResponse.isExists()) { + final String masterKey = (String) getMasterKeyResponse.getSourceAsMap().get(MASTER_KEY); + this.masterKey = masterKey; + log.info("ML encryption master key already initialized, no action needed"); + latch.countDown(); + } else { + exceptionRef.set(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR)); + latch.countDown(); + } + }, error -> { + log.debug("Failed to get ML encryption master key", e); + exceptionRef.set(error); + latch.countDown(); + })); + } + } else { + log.debug("Failed to index ML encryption master key", e); + exceptionRef.set(e); + latch.countDown(); + } + })); } else { - exceptionRef.set(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR)); + final String masterKey = (String) getResponse.getSourceAsMap().get(MASTER_KEY); + this.masterKey = masterKey; + log.info("ML encryption master key already initialized, no action needed"); + latch.countDown(); } }, e -> { - log.error("Failed to get ML encryption master key", e); + log.debug("Failed to get ML encryption master key from config index", e); exceptionRef.set(e); - }), latch), () -> context.restore())); + latch.countDown(); + })); } - } else { - exceptionRef.set(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR)); + }, e -> { + log.debug("Failed to init ML config index", e); + exceptionRef.set(e); latch.countDown(); - } + })); try { - latch.await(5, SECONDS); + latch.await(1, SECONDS); } catch (InterruptedException e) { throw new IllegalStateException(e); } @@ -142,4 +191,5 @@ private void initMasterKey() { throw new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR); } } + } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java index 211ea017c3..bd228fb665 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java @@ -2,6 +2,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.CommonValue.CREATE_TIME_FIELD; @@ -9,7 +10,9 @@ import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; import static org.opensearch.ml.engine.encryptor.EncryptorImpl.MASTER_KEY_NOT_READY_ERROR; +import java.io.IOException; import java.time.Instant; +import java.util.Map; import org.junit.Assert; import org.junit.Before; @@ -21,6 +24,7 @@ import org.opensearch.ResourceNotFoundException; import org.opensearch.Version; import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexResponse; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.IndexMetadata; @@ -30,6 +34,10 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.ml.common.exception.MLException; +import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.threadpool.ThreadPool; import com.google.common.collect.ImmutableMap; @@ -46,6 +54,9 @@ public class EncryptorImplTest { @Mock ClusterState clusterState; + @Mock + private MLIndicesHandler mlIndicesHandler; + String masterKey; @Mock @@ -100,14 +111,319 @@ public void setUp() { } @Test - public void encrypt() { - Encryptor encryptor = new EncryptorImpl(clusterService, client); + public void encrypt_ExistingMasterKey() { + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(0); + actionListener.onResponse(true); + return null; + }).when(mlIndicesHandler).initMLConfigIndex(any()); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(true); + when(response.getSourceAsMap()).thenReturn(Map.of(MASTER_KEY, masterKey)); + actionListener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); Assert.assertNull(encryptor.getMasterKey()); String encrypted = encryptor.encrypt("test"); Assert.assertNotNull(encrypted); Assert.assertEquals(masterKey, encryptor.getMasterKey()); } + @Test + public void encrypt_NonExistingMasterKey() { + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(0); + actionListener.onResponse(true); + return null; + }).when(mlIndicesHandler).initMLConfigIndex(any()); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(false); + actionListener.onResponse(response); + return null; + }).when(client).get(any(), any()); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + IndexResponse response = mock(IndexResponse.class); + actionListener.onResponse(response); + return null; + }).when(client).index(any(), any()); + + Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + Assert.assertNull(encryptor.getMasterKey()); + String encrypted = encryptor.encrypt("test"); + Assert.assertNotNull(encrypted); + Assert.assertNotEquals(masterKey, encryptor.getMasterKey()); + } + + @Test + public void encrypt_NonExistingMasterKey_FailedToCreateNewKey() { + exceptionRule.expect(RuntimeException.class); + exceptionRule.expectMessage("random test exception"); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(0); + actionListener.onResponse(true); + return null; + }).when(mlIndicesHandler).initMLConfigIndex(any()); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(false); + actionListener.onResponse(response); + return null; + }).when(client).get(any(), any()); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + actionListener.onFailure(new RuntimeException("random test exception")); + return null; + }).when(client).index(any(), any()); + + Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + Assert.assertNull(encryptor.getMasterKey()); + encryptor.encrypt("test"); + } + + @Test + public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_NonRuntimeException() { + exceptionRule.expect(MLException.class); + exceptionRule.expectMessage("random IO exception"); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(0); + actionListener.onResponse(true); + return null; + }).when(mlIndicesHandler).initMLConfigIndex(any()); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(false); + actionListener.onResponse(response); + return null; + }).when(client).get(any(), any()); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + actionListener.onFailure(new IOException("random IO exception")); + return null; + }).when(client).index(any(), any()); + + Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + Assert.assertNull(encryptor.getMasterKey()); + encryptor.encrypt("test"); + } + + @Test + public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_VersionConflict() { + exceptionRule.expect(ResourceNotFoundException.class); + exceptionRule.expectMessage(MASTER_KEY_NOT_READY_ERROR); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(0); + actionListener.onResponse(true); + return null; + }).when(mlIndicesHandler).initMLConfigIndex(any()); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(false); + actionListener.onResponse(response); + return null; + }).doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(false); + actionListener.onResponse(response); + return null; + }).when(client).get(any(), any()); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + actionListener + .onFailure(new VersionConflictEngineException(new ShardId(ML_CONFIG_INDEX, "index_uuid", 1), "test_id", "failed")); + return null; + }).when(client).index(any(), any()); + + Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + Assert.assertNull(encryptor.getMasterKey()); + encryptor.encrypt("test"); + } + + @Test + public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_VersionConflict_NullGetResponse() { + exceptionRule.expect(ResourceNotFoundException.class); + exceptionRule.expectMessage(MASTER_KEY_NOT_READY_ERROR); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(0); + actionListener.onResponse(true); + return null; + }).when(mlIndicesHandler).initMLConfigIndex(any()); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(false); + actionListener.onResponse(response); + return null; + }).doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + GetResponse response = null; + actionListener.onResponse(response); + return null; + }).when(client).get(any(), any()); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + actionListener + .onFailure(new VersionConflictEngineException(new ShardId(ML_CONFIG_INDEX, "index_uuid", 1), "test_id", "failed")); + return null; + }).when(client).index(any(), any()); + + Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + Assert.assertNull(encryptor.getMasterKey()); + encryptor.encrypt("test"); + } + + @Test + public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_VersionConflict_NullResponse() { + exceptionRule.expect(ResourceNotFoundException.class); + exceptionRule.expectMessage(MASTER_KEY_NOT_READY_ERROR); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(0); + actionListener.onResponse(true); + return null; + }).when(mlIndicesHandler).initMLConfigIndex(any()); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + GetResponse response = null; + actionListener.onResponse(response); + return null; + }).doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(false); + actionListener.onResponse(response); + return null; + }).when(client).get(any(), any()); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + actionListener + .onFailure(new VersionConflictEngineException(new ShardId(ML_CONFIG_INDEX, "index_uuid", 1), "test_id", "failed")); + return null; + }).when(client).index(any(), any()); + + Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + Assert.assertNull(encryptor.getMasterKey()); + encryptor.encrypt("test"); + } + + @Test + public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_VersionConflict_GetExistingMasterKey() { + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(0); + actionListener.onResponse(true); + return null; + }).when(mlIndicesHandler).initMLConfigIndex(any()); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(false); + actionListener.onResponse(response); + return null; + }).doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(true); + when(response.getSourceAsMap()).thenReturn(Map.of(MASTER_KEY, masterKey)); + actionListener.onResponse(response); + return null; + }).when(client).get(any(), any()); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + actionListener + .onFailure(new VersionConflictEngineException(new ShardId(ML_CONFIG_INDEX, "index_uuid", 1), "test_id", "failed")); + return null; + }).when(client).index(any(), any()); + + Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + Assert.assertNull(encryptor.getMasterKey()); + String encrypted = encryptor.encrypt("test"); + Assert.assertNotNull(encrypted); + Assert.assertEquals(masterKey, encryptor.getMasterKey()); + } + + @Test + public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_VersionConflict_FailedToGetExistingMasterKey() { + exceptionRule.expect(RuntimeException.class); + exceptionRule.expectMessage("random test exception"); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(0); + actionListener.onResponse(true); + return null; + }).when(mlIndicesHandler).initMLConfigIndex(any()); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(false); + actionListener.onResponse(response); + return null; + }).doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + actionListener.onFailure(new RuntimeException("random test exception")); + return null; + }).when(client).get(any(), any()); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + actionListener + .onFailure(new VersionConflictEngineException(new ShardId(ML_CONFIG_INDEX, "index_uuid", 1), "test_id", "failed")); + return null; + }).when(client).index(any(), any()); + + Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + Assert.assertNull(encryptor.getMasterKey()); + String encrypted = encryptor.encrypt("test"); + Assert.assertNotNull(encrypted); + Assert.assertEquals(masterKey, encryptor.getMasterKey()); + } + + @Test + public void encrypt_ThrowExceptionWhenInitMLConfigIndex() { + exceptionRule.expect(RuntimeException.class); + exceptionRule.expectMessage("test exception"); + doThrow(new RuntimeException("test exception")).when(mlIndicesHandler).initMLConfigIndex(any()); + Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + encryptor.encrypt(masterKey); + } + + @Test + public void encrypt_FailedToInitMLConfigIndex() { + exceptionRule.expect(RuntimeException.class); + exceptionRule.expectMessage("random test exception"); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(0); + actionListener.onFailure(new RuntimeException("random test exception")); + return null; + }).when(mlIndicesHandler).initMLConfigIndex(any()); + Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + encryptor.encrypt(masterKey); + } + + @Test + public void encrypt_FailedToGetMasterKey() { + exceptionRule.expect(RuntimeException.class); + exceptionRule.expectMessage("random test exception"); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(0); + actionListener.onResponse(true); + return null; + }).when(mlIndicesHandler).initMLConfigIndex(any()); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + actionListener.onFailure(new RuntimeException("random test exception")); + return null; + }).when(client).get(any(), any()); + Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + encryptor.encrypt(masterKey); + } + @Test public void encrypt_DifferentMasterKey() { Encryptor encryptor = new EncryptorImpl(masterKey); @@ -121,7 +437,22 @@ public void encrypt_DifferentMasterKey() { @Test public void decrypt() { - Encryptor encryptor = new EncryptorImpl(clusterService, client); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(0); + actionListener.onResponse(true); + return null; + }).when(mlIndicesHandler).initMLConfigIndex(any()); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(true); + when(response.getSourceAsMap()) + .thenReturn(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli())); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); Assert.assertNull(encryptor.getMasterKey()); String encrypted = encryptor.encrypt("test"); String decrypted = encryptor.decrypt(encrypted); @@ -142,7 +473,7 @@ public void encrypt_NullMasterKey_NullMasterKey_MasterKeyNotExistInIndex() { return null; }).when(client).get(any(), any()); - Encryptor encryptor = new EncryptorImpl(clusterService, client); + Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); Assert.assertNull(encryptor.getMasterKey()); encryptor.encrypt("test"); } @@ -152,13 +483,18 @@ public void decrypt_NullMasterKey_GetMasterKey_Exception() { exceptionRule.expect(RuntimeException.class); exceptionRule.expectMessage("test error"); + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(0); + actionListener.onResponse(true); + return null; + }).when(mlIndicesHandler).initMLConfigIndex(any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onFailure(new RuntimeException("test error")); return null; }).when(client).get(any(), any()); - Encryptor encryptor = new EncryptorImpl(clusterService, client); + Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); Assert.assertNull(encryptor.getMasterKey()); encryptor.decrypt("test"); } @@ -177,7 +513,7 @@ public void decrypt_MLConfigIndexNotFound() { return null; }).when(client).get(any(), any()); - Encryptor encryptor = new EncryptorImpl(clusterService, client); + Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); Assert.assertNull(encryptor.getMasterKey()); encryptor.decrypt("test"); } diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java index 6feaff32c7..44d75638f4 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java @@ -22,6 +22,7 @@ import java.util.concurrent.Semaphore; import java.util.stream.Collectors; +import org.opensearch.action.DocWriteRequest; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.get.GetRequest; import org.opensearch.action.index.IndexRequest; @@ -231,6 +232,7 @@ void initMLConfig() { final String masterKey = encryptor.generateMasterKey(); indexRequest.source(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli())); indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + indexRequest.opType(DocWriteRequest.OpType.CREATE); client.index(indexRequest, ActionListener.wrap(indexResponse -> { log.info("ML configuration initialized successfully"); encryptor.setMasterKey(masterKey); diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 215be84d4f..bbb4f2ea11 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -459,7 +459,8 @@ public Collection createComponents( Path dataPath = environment.dataFiles()[0]; Path configFile = environment.configFile(); - encryptor = new EncryptorImpl(clusterService, client); + mlIndicesHandler = new MLIndicesHandler(clusterService, client); + encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); mlEngine = new MLEngine(dataPath, encryptor); nodeHelper = new DiscoveryNodeHelper(clusterService, settings); @@ -493,7 +494,6 @@ public Collection createComponents( stats.put(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); - mlIndicesHandler = new MLIndicesHandler(clusterService, client); mlTaskManager = new MLTaskManager(client, threadPool, mlIndicesHandler); modelHelper = new ModelHelper(mlEngine);