Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored HybridQueryPhaseSearcherTests to remove knn specific classes #977

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ coverage:
changes: yes

# disable comments in PRs
comment: yes
comment: yes
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,8 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import org.apache.lucene.document.FieldType;
Expand All @@ -39,35 +36,19 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.index.Index;
import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.mapper.KNNMappingConfig;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;
import org.opensearch.knn.index.query.KNNQueryBuilder;

import com.carrotsearch.randomizedtesting.RandomizedTest;

import lombok.SneakyThrows;

public class HybridQueryTests extends OpenSearchQueryTestCase {

static final String VECTOR_FIELD_NAME = "vectorField";
static final String TERM_QUERY_TEXT = "keyword";
static final String TERM_ANOTHER_QUERY_TEXT = "anotherkeyword";
static final float[] VECTOR_QUERY = new float[] { 1.0f, 2.0f, 2.1f, 0.6f };
static final int K = 2;
@Mock
protected ClusterService clusterService;
private AutoCloseable openMocks;
Expand All @@ -76,11 +57,6 @@ public class HybridQueryTests extends OpenSearchQueryTestCase {
public void setUp() throws Exception {
super.setUp();
openMocks = MockitoAnnotations.openMocks(this);
// This is required to make sure that before every test we are initializing the KNNSettings. Not doing this
// leads to failures of unit tests cases when a unit test is run separately. Try running this test:
// ./gradlew ':test' --tests "org.opensearch.knn.training.TrainingJobTests.testRun_success" and see it fails
// but if run along with other tests this test passes.
initKNNSettings();
}

@Override
Expand Down Expand Up @@ -141,30 +117,16 @@ public void testRewrite_whenRewriteQuery_thenSuccessful() {
w.commit();

IndexReader reader = DirectoryReader.open(w);

// Test with TermQuery
HybridQuery hybridQueryWithTerm = new HybridQuery(
List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext))
);
Query rewritten = hybridQueryWithTerm.rewrite(reader);
// term query is the same after we rewrite it
assertSame(hybridQueryWithTerm, rewritten);

Index dummyIndex = new Index("dummy", "dummy");
KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class);
KNNMappingConfig mockKNNMappingConfig = mock(KNNMappingConfig.class);
KNNMethodContext mockKNNMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, MethodComponentContext.EMPTY);
when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(mockKNNMappingConfig);
when(mockKNNMappingConfig.getKnnMethodContext()).thenReturn(Optional.ofNullable(mockKNNMethodContext));
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getKnnMappingConfig().getDimension()).thenReturn(4);
when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);
when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT);
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(VECTOR_FIELD_NAME, VECTOR_QUERY, K);
Query knnQuery = knnQueryBuilder.toQuery(mockQueryShardContext);

HybridQuery hybridQueryWithKnn = new HybridQuery(List.of(knnQuery));
rewritten = hybridQueryWithKnn.rewrite(reader);
assertSame(hybridQueryWithKnn, rewritten);

// Test empty query list
IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of()));
assertThat(exception.getMessage(), containsString("collection of queries must not be empty"));

Expand Down Expand Up @@ -353,17 +315,4 @@ public void testFilter_whenSubQueriesWithFilterPassed_thenSuccessful() {
}
assertEquals(2, countOfQueries);
}

private void initKNNSettings() {
Set<Setting<?>> defaultClusterSettings = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
defaultClusterSettings.addAll(
KNNSettings.state()
.getSettings()
.stream()
.filter(s -> s.getProperties().contains(Setting.Property.NodeScope))
.collect(Collectors.toList())
);
when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, defaultClusterSettings));
KNNSettings.state().setClusterService(clusterService);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@
import org.opensearch.index.remote.RemoteStoreEnums;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.shard.SearchOperationListener;
import org.opensearch.knn.index.mapper.KNNMappingConfig;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
import org.opensearch.search.SearchShardTarget;
Expand All @@ -80,7 +78,6 @@
import org.opensearch.search.query.ReduceableSearchResult;

public class HybridQueryPhaseSearcherTests extends OpenSearchQueryTestCase {
private static final String VECTOR_FIELD_NAME = "vectorField";
private static final String TEXT_FIELD_NAME = "field";
private static final String TEST_DOC_TEXT1 = "Hello world";
private static final String TEST_DOC_TEXT2 = "Hi to this place";
Expand All @@ -94,12 +91,7 @@ public class HybridQueryPhaseSearcherTests extends OpenSearchQueryTestCase {
public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() {
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher());
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class);
KNNMappingConfig mockKNNMappingConfig = mock(KNNMappingConfig.class);
when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(mockKNNMappingConfig);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getKnnMappingConfig().getDimension()).thenReturn(4);
when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);
MapperService mapperService = createMapperService();
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
Expand Down Expand Up @@ -153,8 +145,11 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() {

HybridQueryBuilder queryBuilder = new HybridQueryBuilder();

TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1);
queryBuilder.add(termSubQuery);
// Add multiple term queries to simulate a complex hybrid query
TermQueryBuilder termSubQuery1 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1);
TermQueryBuilder termSubQuery2 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2);
queryBuilder.add(termSubQuery1);
queryBuilder.add(termSubQuery2);

Query query = queryBuilder.toQuery(mockQueryShardContext);
when(searchContext.query()).thenReturn(query);
Expand Down Expand Up @@ -826,12 +821,7 @@ public void testAggsProcessor_whenGettingAggsProcessor_thenSuccess() {
public void testAggregations_whenMetricAggregation_thenSuccessful() {
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher());
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class);
KNNMappingConfig mockKNNMappingConfig = mock(KNNMappingConfig.class);
when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(mockKNNMappingConfig);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getKnnMappingConfig().getDimension()).thenReturn(4);
when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);
MapperService mapperService = createMapperService();
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
Expand Down Expand Up @@ -886,8 +876,11 @@ public void testAggregations_whenMetricAggregation_thenSuccessful() {

HybridQueryBuilder queryBuilder = new HybridQueryBuilder();

TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1);
queryBuilder.add(termSubQuery);
// Add multiple queries to simulate a complex hybrid query
TermQueryBuilder termSubQuery1 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1);
TermQueryBuilder termSubQuery2 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2);
queryBuilder.add(termSubQuery1);
queryBuilder.add(termSubQuery2);

Query query = queryBuilder.toQuery(mockQueryShardContext);
when(searchContext.query()).thenReturn(query);
Expand Down
Loading