From 7e4a2bd6c68c41473c52f6850ef9346a7383511e Mon Sep 17 00:00:00 2001 From: Ganesh Krishna Ramadurai Date: Mon, 27 Jan 2025 13:06:55 -0800 Subject: [PATCH] Concurrency optimization for graph native loading update (#2441) Signed-off-by: Ganesh Ramadurai --- CHANGELOG.md | 2 + .../memory/NativeMemoryCacheManager.java | 70 ++++- .../memory/NativeMemoryEntryContext.java | 82 ++++- .../memory/NativeMemoryLoadStrategy.java | 17 +- .../memory/NativeMemoryCacheManagerTests.java | 285 +++++++++++++++++- .../memory/NativeMemoryEntryContextTests.java | 39 ++- .../memory/NativeMemoryLoadStrategyTests.java | 10 +- 7 files changed, 485 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bc1377ef1..2cf0c8313 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add expand_nested_docs Parameter support to NMSLIB engine (#2331)[https://github.com/opensearch-project/k-NN/pull/2331] - Add a new build mode, `FAISS_OPT_LEVEL=avx512_spr`, which enables the use of advanced AVX-512 instructions introduced with Intel(R) Sapphire Rapids (#2404)[https://github.com/opensearch-project/k-NN/pull/2404] - Add cosine similarity support for faiss engine (#2376)[https://github.com/opensearch-project/k-NN/pull/2376] +- Add concurrency optimizations with native memory graph loading and force eviction (#2265) [https://github.com/opensearch-project/k-NN/pull/2345] + ### Enhancements - Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241] - Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290] diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java index 76e94ee66..5641e6fa3 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java @@ -35,12 +35,14 @@ import java.util.Iterator; import java.util.Map; import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.locks.ReentrantLock; /** * Manages native memory allocations made by JNI. @@ -56,6 +58,7 @@ public class NativeMemoryCacheManager implements Closeable { private Cache cache; private Deque accessRecencyQueue; + private final ConcurrentHashMap indexLocks = new ConcurrentHashMap<>(); private final ExecutorService executor; private AtomicBoolean cacheCapacityReached; private long maxWeight; @@ -297,6 +300,55 @@ public CacheStats getCacheStats() { return cache.stats(); } + /** + * Opens a vector index with proper locking mechanism to ensure thread safety. + * The method uses a ReentrantLock to synchronize access to the index file and + * cleans up the lock when no other threads are waiting. + * + * @param key the unique identifier for the index + * @param nativeMemoryEntryContext the context containing vector index information + */ + private void openIndex(String key, NativeMemoryEntryContext nativeMemoryEntryContext) { + ReentrantLock indexFileLock = indexLocks.computeIfAbsent(key, k -> new ReentrantLock()); + try { + indexFileLock.lock(); + nativeMemoryEntryContext.openVectorIndex(); + } finally { + indexFileLock.unlock(); + if (!indexFileLock.hasQueuedThreads()) { + indexLocks.remove(key, indexFileLock); + } + } + } + + /** + * Retrieves an entry from the cache and updates its access recency if found. + * This method combines cache access with recency queue management to maintain + * the least recently used (LRU) order of cached entries. + * + * @param key the unique identifier for the cached entry + * @return the cached NativeMemoryAllocation if present, null otherwise + */ + private NativeMemoryAllocation getFromCacheAndUpdateRecency(String key) { + NativeMemoryAllocation result = cache.getIfPresent(key); + if (result != null) { + updateAccessRecency(key); + } + return result; + } + + /** + * Updates the access recency of a cached entry by moving it to the end of the queue. + * This method maintains the least recently used (LRU) order by removing the entry + * from its current position and adding it to the end of the queue. + * + * @param key the unique identifier for the cached entry whose recency needs to be updated + */ + private void updateAccessRecency(String key) { + accessRecencyQueue.remove(key); + accessRecencyQueue.addLast(key); + } + /** * Retrieves NativeMemoryAllocation associated with the nativeMemoryEntryContext. * @@ -329,7 +381,6 @@ public NativeMemoryAllocation get(NativeMemoryEntryContext nativeMemoryEntryC // In case of a cache miss, least recently accessed entries are evicted in a blocking manner // before the new entry can be added to the cache. String key = nativeMemoryEntryContext.getKey(); - NativeMemoryAllocation result = cache.getIfPresent(key); // Cache Hit // In case of a cache hit, moving the item to the end of the recency queue adds @@ -337,15 +388,21 @@ public NativeMemoryAllocation get(NativeMemoryEntryContext nativeMemoryEntryC // as lightweight as possible. Multiple approaches and their outcomes were documented // before moving forward with the current solution. // The details are outlined here: https://github.com/opensearch-project/k-NN/pull/2015#issuecomment-2327064680 + NativeMemoryAllocation result = getFromCacheAndUpdateRecency(key); if (result != null) { - accessRecencyQueue.remove(key); - accessRecencyQueue.addLast(key); return result; } // Cache Miss // Evict before put + // open the graph file before proceeding to load the graph into memory + openIndex(key, nativeMemoryEntryContext); synchronized (this) { + // recheck if another thread already loaded this entry into the cache + result = getFromCacheAndUpdateRecency(key); + if (result != null) { + return result; + } if (getCacheSizeInKilobytes() + nativeMemoryEntryContext.calculateSizeInKB() >= maxWeight) { Iterator lruIterator = accessRecencyQueue.iterator(); while (lruIterator.hasNext() @@ -367,7 +424,12 @@ public NativeMemoryAllocation get(NativeMemoryEntryContext nativeMemoryEntryC return result; } } else { - return cache.get(nativeMemoryEntryContext.getKey(), nativeMemoryEntryContext::load); + // open graphFile before load + try (nativeMemoryEntryContext) { + String key = nativeMemoryEntryContext.getKey(); + openIndex(key, nativeMemoryEntryContext); + return cache.get(key, nativeMemoryEntryContext::load); + } } } diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java index 0af13fb46..9f8d4bcce 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java @@ -12,12 +12,16 @@ package org.opensearch.knn.index.memory; import lombok.Getter; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Nullable; import org.opensearch.knn.index.codec.util.NativeMemoryCacheKeyHelper; import org.opensearch.knn.index.engine.qframe.QuantizationConfig; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.store.IndexInputWithBuffer; import java.io.IOException; import java.util.Map; @@ -26,7 +30,7 @@ /** * Encapsulates all information needed to load a component into native memory. */ -public abstract class NativeMemoryEntryContext { +public abstract class NativeMemoryEntryContext implements AutoCloseable { protected final String key; @@ -55,6 +59,19 @@ public String getKey() { */ public abstract Integer calculateSizeInKB(); + /** + * Opens the graph file by opening the corresponding indexInput so + * that it is available for graph loading + */ + + public void openVectorIndex() {} + + /** + * Provides the capability to close the closable objects in the {@link NativeMemoryEntryContext} + */ + @Override + public void close() {} + /** * Loads entry into memory. * @@ -62,6 +79,7 @@ public String getKey() { */ public abstract T load() throws IOException; + @Log4j2 public static class IndexEntryContext extends NativeMemoryEntryContext { @Getter @@ -75,6 +93,17 @@ public static class IndexEntryContext extends NativeMemoryEntryContext { diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java index 8cbdb4fd7..da3b73ec0 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java @@ -13,12 +13,9 @@ import lombok.extern.log4j.Log4j2; import org.apache.lucene.store.Directory; -import org.apache.lucene.store.IOContext; -import org.apache.lucene.store.IndexInput; import org.opensearch.core.action.ActionListener; import org.opensearch.knn.index.codec.util.NativeMemoryCacheKeyHelper; import org.opensearch.knn.index.engine.qframe.QuantizationConfig; -import org.opensearch.knn.index.store.IndexInputWithBuffer; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.engine.KNNEngine; @@ -88,10 +85,16 @@ public NativeMemoryAllocation.IndexAllocation load(NativeMemoryEntryContext.Inde final int indexSizeKb = Math.toIntExact(directory.fileLength(vectorFileName) / 1024); // Try to open an index input then pass it down to native engine for loading an index. - try (IndexInput readStream = directory.openInput(vectorFileName, IOContext.READONCE)) { - final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(readStream); - final long indexAddress = JNIService.loadIndex(indexInputWithBuffer, indexEntryContext.getParameters(), knnEngine); - + // openVectorIndex takes care of opening the indexInput file + if (!indexEntryContext.isIndexGraphFileOpened()) { + throw new IllegalStateException("Index [" + indexEntryContext.getOpenSearchIndexName() + "] is not preloaded"); + } + try (indexEntryContext) { + final long indexAddress = JNIService.loadIndex( + indexEntryContext.indexInputWithBuffer, + indexEntryContext.getParameters(), + knnEngine + ); return createIndexAllocation(indexEntryContext, knnEngine, indexAddress, indexSizeKb, vectorFileName); } } diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java index 5cdedf11b..3de474076 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java @@ -12,10 +12,21 @@ package org.opensearch.knn.index.memory; import com.google.common.cache.CacheStats; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IndexInput; import org.junit.Before; +import lombok.SneakyThrows; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import org.opensearch.knn.TestUtils; import org.opensearch.knn.common.exception.OutOfNativeMemoryException; +import org.opensearch.knn.common.featureflags.KNNFeatureFlags; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.plugin.KNNPlugin; @@ -28,10 +39,24 @@ import java.util.Collection; import java.util.Collections; import java.util.Map; +import java.util.HashSet; +import java.util.ArrayList; +import java.util.Set; +import java.util.List; import java.util.concurrent.ExecutionException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.spy; import static org.opensearch.knn.index.memory.NativeMemoryCacheManager.GRAPH_COUNT; import static org.opensearch.knn.plugin.stats.StatNames.GRAPH_MEMORY_USAGE; @@ -39,8 +64,27 @@ public class NativeMemoryCacheManagerTests extends OpenSearchSingleNodeTestCase private ThreadPool threadPool; + @Mock + protected ClusterService clusterService; + @Mock + protected ClusterSettings clusterSettings; + + protected AutoCloseable openMocks; + @Before - public void setThreadPool() { + public void setup() { + openMocks = MockitoAnnotations.openMocks(this); + clusterService = mock(ClusterService.class); + Set> defaultClusterSettings = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + defaultClusterSettings.addAll( + KNNSettings.state() + .getSettings() + .stream() + .filter(s -> s.getProperties().contains(Setting.Property.NodeScope)) + .collect(Collectors.toList()) + ); + KNNSettings.state().setClusterService(clusterService); + when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, defaultClusterSettings)); threadPool = new ThreadPool(Settings.builder().put("node.name", "NativeMemoryCacheManagerTests").build()); NativeMemoryCacheManager.setThreadPool(threadPool); } @@ -493,6 +537,242 @@ public void testMaintenanceScheduled() { assertTrue(maintenanceTask.isCancelled()); } + @Test + public void checkFeatureFlag() { + KNNSettings.state().setClusterService(clusterService); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterSettings.get(KNNFeatureFlags.KNN_FORCE_EVICT_CACHE_ENABLED_SETTING)).thenReturn(true); + assertTrue(KNNFeatureFlags.isForceEvictCacheEnabled()); + when(clusterSettings.get(KNNFeatureFlags.KNN_FORCE_EVICT_CACHE_ENABLED_SETTING)).thenReturn(false); + assertFalse(KNNFeatureFlags.isForceEvictCacheEnabled()); + } + + @SneakyThrows + @Test + public void testGet() { + NativeMemoryCacheManager nativeMemoryCacheManager = new NativeMemoryCacheManager(); + Map> indicesStats = nativeMemoryCacheManager.getIndicesCacheStats(); + assertTrue(indicesStats.isEmpty()); + + String indexName1 = "test-index-1"; + String testKey1 = "test-1"; + int size1 = 3; + NativeMemoryAllocation.IndexAllocation indexAllocation1 = new NativeMemoryAllocation.IndexAllocation( + null, + 0, + size1, + null, + testKey1, + indexName1 + ); + + NativeMemoryLoadStrategy.IndexLoadStrategy indexLoadStrategy = mock(NativeMemoryLoadStrategy.IndexLoadStrategy.class); + NativeMemoryEntryContext.IndexEntryContext indexEntryContext1 = spy( + new NativeMemoryEntryContext.IndexEntryContext( + (Directory) null, + TestUtils.createFakeNativeMamoryCacheKey("test"), + indexLoadStrategy, + null, + "test" + ) + ); + + doReturn(indexAllocation1).when(indexEntryContext1).load(); + + doReturn(0).when(indexEntryContext1).calculateSizeInKB(); + Directory mockDirectory = mock(Directory.class); + IndexInput mockReadStream = mock(IndexInput.class); + when(mockDirectory.openInput(any(), any())).thenReturn(mockReadStream); + // Add this line to handle the fileLength call + when(mockDirectory.fileLength(any())).thenReturn(1024L); // 1KB for testing + doReturn(mockDirectory).when(indexEntryContext1).getDirectory(); + assertFalse(indexEntryContext1.isIndexGraphFileOpened()); + assertEquals(indexAllocation1, nativeMemoryCacheManager.get(indexEntryContext1, false)); + // try-with-resources will anyway close the resources opened by indexEntryContext1 + assertFalse(indexEntryContext1.isIndexGraphFileOpened()); + assertEquals(indexAllocation1, nativeMemoryCacheManager.get(indexEntryContext1, false)); + + verify(mockDirectory, times(2)).openInput(any(), any()); + verify(mockReadStream, times(2)).seek(0); + verify(mockReadStream, times(2)).close(); + + } + + @SneakyThrows + @Test(expected = NullPointerException.class) + public void testGetWithInvalidFile_NullPointerException() { + NativeMemoryCacheManager nativeMemoryCacheManager = new NativeMemoryCacheManager(); + + NativeMemoryLoadStrategy.IndexLoadStrategy indexLoadStrategy = mock(NativeMemoryLoadStrategy.IndexLoadStrategy.class); + NativeMemoryEntryContext.IndexEntryContext indexEntryContext = spy( + new NativeMemoryEntryContext.IndexEntryContext((Directory) null, "invalid-cache-key", indexLoadStrategy, null, "test") + ); + + Directory mockDirectory = mock(Directory.class); + // This should throw the exception + nativeMemoryCacheManager.get(indexEntryContext, false); + } + + @SneakyThrows + @Test(expected = IllegalStateException.class) + public void testGetWithInvalidFile_IllegalStateException() { + NativeMemoryCacheManager nativeMemoryCacheManager = new NativeMemoryCacheManager(); + + NativeMemoryLoadStrategy.IndexLoadStrategy indexLoadStrategy = mock(NativeMemoryLoadStrategy.IndexLoadStrategy.class); + NativeMemoryEntryContext.IndexEntryContext indexEntryContext = spy( + new NativeMemoryEntryContext.IndexEntryContext((Directory) null, "invalid-cache-key", indexLoadStrategy, null, "test") + ); + + doReturn(0).when(indexEntryContext).calculateSizeInKB(); + Directory mockDirectory = mock(Directory.class); + // This should throw the exception + nativeMemoryCacheManager.get(indexEntryContext, false); + } + + @SneakyThrows + @Test + public void getWithForceEvictEnabled() { + NativeMemoryCacheManager nativeMemoryCacheManager = new NativeMemoryCacheManager(); + clusterService = mock(ClusterService.class); + KNNSettings.state().setClusterService(clusterService); + clusterSettings = mock(ClusterSettings.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterSettings.get(KNNFeatureFlags.KNN_FORCE_EVICT_CACHE_ENABLED_SETTING)).thenReturn(true); + + String testKey1 = "test-1"; + String indexName1 = "test-index-1"; + int size1 = 3; + + NativeMemoryAllocation.IndexAllocation indexAllocation1 = new NativeMemoryAllocation.IndexAllocation( + null, + 0, + size1, + null, + testKey1, + indexName1 + ); + + NativeMemoryLoadStrategy.IndexLoadStrategy indexLoadStrategy = mock(NativeMemoryLoadStrategy.IndexLoadStrategy.class); + NativeMemoryEntryContext.IndexEntryContext indexEntryContext1 = spy( + new NativeMemoryEntryContext.IndexEntryContext( + (Directory) null, + TestUtils.createFakeNativeMamoryCacheKey("test"), + indexLoadStrategy, + null, + "test" + ) + ); + + doReturn(indexAllocation1).when(indexEntryContext1).load(); + doReturn(0).when(indexEntryContext1).calculateSizeInKB(); + Directory mockDirectory = mock(Directory.class); + IndexInput mockReadStream = mock(IndexInput.class); + when(mockDirectory.openInput(any(), any())).thenReturn(mockReadStream); + when(mockDirectory.fileLength(any())).thenReturn(1024L); + doReturn(mockDirectory).when(indexEntryContext1).getDirectory(); + + assertFalse(indexEntryContext1.isIndexGraphFileOpened()); + assertEquals(indexAllocation1, nativeMemoryCacheManager.get(indexEntryContext1, false)); + // In force evict path, the file should stay open since it's not in a try-with-resources + assertTrue(indexEntryContext1.isIndexGraphFileOpened()); + + assertEquals(indexAllocation1, nativeMemoryCacheManager.get(indexEntryContext1, false)); + assertTrue(indexEntryContext1.isIndexGraphFileOpened()); + + // Should only be called once since second call is a cache hit + verify(mockDirectory, times(1)).openInput(any(), any()); + verify(mockReadStream, times(1)).seek(0); + // Since we're not closing in try-with-resources, close shouldn't be called + verify(mockReadStream, never()).close(); + } + + @Test + @SneakyThrows + public void testConcurrentVectorIndexOpening() { + NativeMemoryCacheManager nativeMemoryCacheManager = new NativeMemoryCacheManager(); + clusterService = mock(ClusterService.class); + KNNSettings.state().setClusterService(clusterService); + clusterSettings = mock(ClusterSettings.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterSettings.get(KNNFeatureFlags.KNN_FORCE_EVICT_CACHE_ENABLED_SETTING)).thenReturn(true); + + int numThreads = 5; + CountDownLatch startLatch = new CountDownLatch(1); + CountDownLatch completionLatch = new CountDownLatch(numThreads); + AtomicInteger openVectorIndexCalls = new AtomicInteger(0); + + // Create test allocation + NativeMemoryAllocation.IndexAllocation indexAllocation = new NativeMemoryAllocation.IndexAllocation( + null, + 0, + 3, + null, + "test-1", + "test-index-1" + ); + + // Create and set up the spy context that will be shared across threads + NativeMemoryLoadStrategy.IndexLoadStrategy indexLoadStrategy = mock(NativeMemoryLoadStrategy.IndexLoadStrategy.class); + NativeMemoryEntryContext.IndexEntryContext sharedContext = spy( + new NativeMemoryEntryContext.IndexEntryContext( + (Directory) null, + TestUtils.createFakeNativeMamoryCacheKey("test"), + indexLoadStrategy, + null, + "test" + ) + ); + + // Set up mocks + doReturn(indexAllocation).when(sharedContext).load(); + doReturn(0).when(sharedContext).calculateSizeInKB(); + Directory mockDirectory = mock(Directory.class); + IndexInput mockReadStream = mock(IndexInput.class); + when(mockDirectory.openInput(any(), any())).thenReturn(mockReadStream); + when(mockDirectory.fileLength(any())).thenReturn(1024L); + doReturn(mockDirectory).when(sharedContext).getDirectory(); + + // Add a delay in openVectorIndex to make concurrent access more likely + doAnswer(invocation -> { + openVectorIndexCalls.incrementAndGet(); + // Add a small delay to simulate work + Thread.sleep(1000); + return invocation.callRealMethod(); + }).when(sharedContext).openVectorIndex(); + + // Create threads that will try to get the same context concurrently + List threads = new ArrayList<>(); + for (int i = 0; i < numThreads; i++) { + Thread t = new Thread(() -> { + try { + startLatch.await(); // Wait for all threads to be ready + nativeMemoryCacheManager.get(sharedContext, false); + } catch (Exception e) { + e.printStackTrace(); + } finally { + completionLatch.countDown(); + } + }); + threads.add(t); + t.start(); + } + + startLatch.countDown(); + + // Wait for all threads to complete + completionLatch.await(); + + // openVectorIndex is called for each of the threads + verify(sharedContext, times(numThreads)).openVectorIndex(); + assertEquals(numThreads, openVectorIndexCalls.get()); + + // but opening of the indexInput and seek only happens once, since rest of the threads will wait for first + // thread and then pick up from cache + verify(mockDirectory, times(1)).openInput(any(), any()); + verify(mockReadStream, times(1)).seek(0); + + } + private static class TestNativeMemoryAllocation implements NativeMemoryAllocation { int size; @@ -571,6 +851,9 @@ public Integer calculateSizeInKB() { return size; } + @Override + public void openVectorIndex() {} + @Override public TestNativeMemoryAllocation load() throws IOException { return new TestNativeMemoryAllocation(size, memoryAddress); diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java index 5379abc74..f5d1ee77c 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java @@ -30,6 +30,8 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.doReturn; public class NativeMemoryEntryContextTests extends KNNTestCase { @@ -41,6 +43,34 @@ public void testAbstract_getKey() { } public void testIndexEntryContext_load() throws IOException { + NativeMemoryLoadStrategy.IndexLoadStrategy indexLoadStrategy = mock(NativeMemoryLoadStrategy.IndexLoadStrategy.class); + NativeMemoryEntryContext.IndexEntryContext indexEntryContext = spy( + new NativeMemoryEntryContext.IndexEntryContext( + (Directory) null, + TestUtils.createFakeNativeMamoryCacheKey("test"), + indexLoadStrategy, + null, + "test" + ) + ); + + NativeMemoryAllocation.IndexAllocation indexAllocation = new NativeMemoryAllocation.IndexAllocation( + null, + 0, + 10, + KNNEngine.DEFAULT, + "test-path", + "test-name" + ); + + when(indexLoadStrategy.load(indexEntryContext)).thenReturn(indexAllocation); + + // since we are returning mock instance, set indexEntryContext.isIndexGraphFileOpened to true. + doReturn(true).when(indexEntryContext).isIndexGraphFileOpened(); + assertEquals(indexAllocation, indexEntryContext.load()); + } + + public void testIndexEntryContext_load_with_unopened_graphFile() throws IOException { NativeMemoryLoadStrategy.IndexLoadStrategy indexLoadStrategy = mock(NativeMemoryLoadStrategy.IndexLoadStrategy.class); NativeMemoryEntryContext.IndexEntryContext indexEntryContext = new NativeMemoryEntryContext.IndexEntryContext( (Directory) null, @@ -59,9 +89,7 @@ public void testIndexEntryContext_load() throws IOException { "test-name" ); - when(indexLoadStrategy.load(indexEntryContext)).thenReturn(indexAllocation); - - assertEquals(indexAllocation, indexEntryContext.load()); + assertThrows(IllegalStateException.class, indexEntryContext::load); } public void testIndexEntryContext_calculateSize() throws IOException { @@ -292,6 +320,11 @@ public Integer calculateSizeInKB() { return size; } + @Override + public void openVectorIndex() { + return; + } + @Override public TestNativeMemoryAllocation load() throws IOException { return null; diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java index 735974bd1..a149185a8 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java @@ -68,9 +68,10 @@ public void testIndexLoadStrategy_load() throws IOException { "test" ); + // open graph file before load + indexEntryContext.openVectorIndex(); // Load - NativeMemoryAllocation.IndexAllocation indexAllocation = NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance() - .load(indexEntryContext); + NativeMemoryAllocation.IndexAllocation indexAllocation = indexEntryContext.load(); // Confirm that the file was loaded by querying float[] query = new float[dimension]; @@ -114,9 +115,10 @@ public void testLoad_whenFaissBinary_thenSuccess() throws IOException { "test" ); + // open graph file before load + indexEntryContext.openVectorIndex(); // Load - NativeMemoryAllocation.IndexAllocation indexAllocation = NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance() - .load(indexEntryContext); + NativeMemoryAllocation.IndexAllocation indexAllocation = indexEntryContext.load(); // Verify assertTrue(indexAllocation.isBinaryIndex());