Skip to content

Commit

Permalink
unit tests for arrow lib
Browse files Browse the repository at this point in the history
  • Loading branch information
rishabhmaurya committed Oct 4, 2024
1 parent bd45429 commit 251f081
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 7 deletions.
8 changes: 2 additions & 6 deletions libs/arrow/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
dependencies {
api project(':libs:opensearch-common')
implementation 'org.slf4j:slf4j-api:1.7.36'
api group: 'org.apache.arrow', name: 'arrow-vector', version: '17.0.0'
api 'org.apache.arrow:arrow-vector:17.0.0'
api 'org.apache.arrow:arrow-memory-core:17.0.0'
api 'org.apache.arrow:arrow-format:17.0.0'
api group: 'org.apache.arrow', name: 'arrow-memory-netty-buffer-patch', version: '17.0.0'
api 'org.apache.arrow:arrow-memory-netty-buffer-patch:17.0.0'
api 'org.apache.arrow:arrow-memory-netty:17.0.0'
api("io.netty:netty-common:${versions.netty}") {
exclude group: 'io.netty', module: 'netty-common'
Expand All @@ -35,7 +35,3 @@ dependencies {
testImplementation "org.hamcrest:hamcrest:${versions.hamcrest}"
testImplementation(project(":test:framework"))
}

tasks.named('forbiddenApisMain').configure {
replaceSignatureFiles 'jdk-signatures'
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ public void removeStream(StreamTicket ticket) {
streams.remove(ticket);
}

public ConcurrentHashMap<StreamTicket, ArrowStreamProvider> getStreams() {
return streams;
}

public abstract StreamTicket generateUniqueTicket();

public void close() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import java.util.Arrays;

public abstract class StreamTicket {
public class StreamTicket {
private final byte[] bytes;
public StreamTicket(byte[] bytes) {
this.bytes = bytes;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.arrow;

import org.apache.arrow.memory.BufferAllocator;
import org.mockito.Mock;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.opensearch.test.OpenSearchTestCase;

import static org.mockito.Mockito.*;

public class StreamManagerTests extends OpenSearchTestCase {

private StreamManager streamManager;

@Mock
private ArrowStreamProvider mockProvider;

private final VectorSchemaRoot mockRoot = mock(VectorSchemaRoot.class);

@Override
public void setUp() throws Exception {
super.setUp();
streamManager = new StreamManager() {
@Override
public VectorSchemaRoot getVectorSchemaRoot(StreamTicket ticket) {
return mockRoot;
}

@Override
public StreamTicket generateUniqueTicket() {
return new StreamTicket(("ticket" + (getStreams().size()+1)).getBytes());
}
};
mockProvider = allocator -> new ArrowStreamProvider.Task() {
@Override
public VectorSchemaRoot init(BufferAllocator allocator) {
return mockRoot;
}

@Override
public void run(VectorSchemaRoot root, ArrowStreamProvider.FlushSignal flushSignal) {

}
};
}

public void testRegisterStream() {
StreamTicket ticket = streamManager.registerStream(mockProvider);
assertNotNull(ticket);
assertEquals(new StreamTicket("ticket1".getBytes()), ticket);
}

public void testGetStream() {
StreamTicket ticket = streamManager.registerStream(mockProvider);
ArrowStreamProvider retrievedProvider = streamManager.getStream(ticket);
assertEquals(mockProvider, retrievedProvider);
}

public void testGetVectorSchemaRoot() {
StreamTicket ticket = streamManager.registerStream(mockProvider);
VectorSchemaRoot root = streamManager.getVectorSchemaRoot(ticket);
assertEquals(mockRoot, root);
}

public void testRemoveStream() {
StreamTicket ticket = streamManager.registerStream(mockProvider);
streamManager.removeStream(ticket);
assertNull(streamManager.getStream(ticket));
}

public void testClose() {
StreamTicket ticket = streamManager.registerStream(mockProvider);
streamManager.close();
assertNull(streamManager.getStream(ticket));
}

public void testMultipleStreams() {
ArrowStreamProvider mockProvider2 = mock(ArrowStreamProvider.class);

StreamTicket ticket1 = streamManager.registerStream(mockProvider);
StreamTicket ticket2 = streamManager.registerStream(mockProvider2);
assertNotEquals(ticket1, ticket2);
assertEquals(2, streamManager.getStreams().size());
}

public void testInvalidTicket() {
StreamTicket invalidTicket = new StreamTicket("invalid-ticket".getBytes());
assertNull(streamManager.getStream(invalidTicket));
}
}

0 comments on commit 251f081

Please sign in to comment.