Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
Signed-off-by: Kiran Prakash <[email protected]>
  • Loading branch information
kiranprakash154 committed Aug 29, 2024
1 parent 4a2c51e commit cbb51bd
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import org.opensearch.monitor.jvm.JvmStats;
import org.opensearch.monitor.process.ProcessProbe;
import org.opensearch.search.ResourceType;
import org.opensearch.tasks.CancellableTask;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskCancellation;
import org.opensearch.wlm.QueryGroupLevelResourceUsageView;

Expand Down Expand Up @@ -173,21 +175,52 @@ private boolean shouldCancelTasks(QueryGroup queryGroup, ResourceType resourceTy
}

private List<TaskCancellation> getTaskCancellations(QueryGroup queryGroup, ResourceType resourceType) {
return defaultTaskSelectionStrategy.selectTasksForCancellation(
queryGroup,
// get the active tasks in the query group
List<Task> selectedTasksToCancel = defaultTaskSelectionStrategy.selectTasksForCancellation(
queryGroupLevelResourceUsageViews.get(queryGroup.get_id()).getActiveTasks(),
getReduceBy(queryGroup, resourceType),
resourceType
);
List<TaskCancellation> taskCancellations = new ArrayList<>();
for(Task task : selectedTasksToCancel) {
String cancellationReason = createCancellationReason(queryGroup, task, resourceType);
taskCancellations.add(createTaskCancellation((CancellableTask) task, cancellationReason));
}
return taskCancellations;
}

private String createCancellationReason(QueryGroup querygroup, Task task, ResourceType resourceType) {
Double thresholdInPercent = getThresholdInPercent(querygroup, resourceType);
return "[Workload Management] Cancelling Task ID : "
+ task.getId()
+ " from QueryGroup ID : "
+ querygroup.get_id()
+ " breached the resource limit of : "
+ thresholdInPercent
+ " for resource type : "
+ resourceType.getName();
}

private Double getThresholdInPercent(QueryGroup querygroup, ResourceType resourceType) {
return ((Double) (querygroup.getResourceLimits().get(resourceType))) * 100;
}

private TaskCancellation createTaskCancellation(CancellableTask task, String cancellationReason) {
return new TaskCancellation(task, List.of(new TaskCancellation.Reason(cancellationReason, 5)), List.of(this::callbackOnCancel));
}

protected List<TaskCancellation> getTaskCancellationsForDeletedQueryGroup(QueryGroup queryGroup) {
return defaultTaskSelectionStrategy.selectTasksFromDeletedQueryGroup(
queryGroup,
// get the active tasks in the query group
List<Task> tasks = defaultTaskSelectionStrategy.selectTasksFromDeletedQueryGroup(
queryGroupLevelResourceUsageViews.get(queryGroup.get_id()).getActiveTasks()
);
List<TaskCancellation> taskCancellations = new ArrayList<>();
for(Task task : tasks) {
String cancellationReason = "[Workload Management] Cancelling Task ID : "
+ task.getId()
+ " from QueryGroup ID : "
+ queryGroup.get_id();
taskCancellations.add(createTaskCancellation((CancellableTask) task, cancellationReason));
}
return taskCancellations;
}

private long getReduceBy(QueryGroup queryGroup, ResourceType resourceType) {
Expand Down Expand Up @@ -229,4 +262,8 @@ private boolean isBreachingThreshold(ResourceType resourceType, Double resourceT
// Check if resource usage is breaching the threshold
return resourceUsageInMillis > convertThresholdIntoLong(resourceType, resourceThresholdInPercentage);
}

private void callbackOnCancel() {
// TODO Implement callback logic here mostly used for Stats
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

package org.opensearch.wlm.cancellation;

import org.opensearch.cluster.metadata.QueryGroup;
import org.opensearch.search.ResourceType;
import org.opensearch.tasks.CancellableTask;
import org.opensearch.tasks.Task;
Expand Down Expand Up @@ -47,8 +46,7 @@ public Comparator<Task> sortingCondition() {
* @return The list of selected tasks
* @throws IllegalArgumentException If the limit is less than zero
*/
public List<TaskCancellation> selectTasksForCancellation(
QueryGroup querygroup,
public List<Task> selectTasksForCancellation(
List<Task> tasks,
long limit,
ResourceType resourceType
Expand All @@ -62,13 +60,11 @@ public List<TaskCancellation> selectTasksForCancellation(

List<Task> sortedTasks = tasks.stream().sorted(sortingCondition()).collect(Collectors.toList());

List<TaskCancellation> selectedTasks = new ArrayList<>();
List<Task> selectedTasks = new ArrayList<>();
long accumulated = 0;

for (Task task : sortedTasks) {
if (task instanceof CancellableTask) {
String cancellationReason = createCancellationReason(querygroup, task, resourceType);
selectedTasks.add(createTaskCancellation((CancellableTask) task, cancellationReason));
selectedTasks.add(task);
accumulated += resourceType.getResourceUsage(task);
if (accumulated >= limit) {
break;
Expand All @@ -84,46 +80,13 @@ public List<TaskCancellation> selectTasksForCancellation(
* {@link CancellableTask}. For each selected task, it creates a cancellation reason and adds
* a {@link TaskCancellation} object to the list of selected tasks.
*
* @param querygroup The {@link QueryGroup} from which the tasks are being selected.
* @param tasks The list of {@link Task} objects to be evaluated for cancellation.
* @return A list of {@link TaskCancellation} objects representing the tasks selected for cancellation.
*/
public List<TaskCancellation> selectTasksFromDeletedQueryGroup(QueryGroup querygroup, List<Task> tasks) {
List<TaskCancellation> selectedTasks = new ArrayList<>();

for (Task task : tasks) {
if (task instanceof CancellableTask) {
String cancellationReason = "[Workload Management] Cancelling Task ID : "
+ task.getId()
+ " from QueryGroup ID : "
+ querygroup.get_id();
selectedTasks.add(createTaskCancellation((CancellableTask) task, cancellationReason));
}
}
return selectedTasks;
}

private String createCancellationReason(QueryGroup querygroup, Task task, ResourceType resourceType) {
Double thresholdInPercent = getThresholdInPercent(querygroup, resourceType);
return "[Workload Management] Cancelling Task ID : "
+ task.getId()
+ " from QueryGroup ID : "
+ querygroup.get_id()
+ " breached the resource limit of : "
+ thresholdInPercent
+ " for resource type : "
+ resourceType.getName();
}

private Double getThresholdInPercent(QueryGroup querygroup, ResourceType resourceType) {
return ((Double) (querygroup.getResourceLimits().get(resourceType))) * 100;
}

private TaskCancellation createTaskCancellation(CancellableTask task, String cancellationReason) {
return new TaskCancellation(task, List.of(new TaskCancellation.Reason(cancellationReason, 5)), List.of(this::callbackOnCancel));
}

private void callbackOnCancel() {
// TODO Implement callback logic here mostly used for Stats
public List<Task> selectTasksFromDeletedQueryGroup(List<Task> tasks) {
return tasks
.stream()
.filter(task -> task instanceof CancellableTask)
.collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
import org.opensearch.wlm.QueryGroupLevelResourceUsageView;
import org.junit.Before;

import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BooleanSupplier;
import java.util.stream.Collectors;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -66,6 +68,35 @@ public void setup() {
);
}

public void testGetCancellableTasksFrom_setupAppropriateCancellationReasonAndScore() {
ResourceType resourceType = ResourceType.CPU;
long usage = 100_000_000L;
Double threshold = 0.1;

QueryGroup queryGroup1 = new QueryGroup(
"testQueryGroup",
queryGroupId1,
QueryGroup.ResiliencyMode.ENFORCED,
Map.of(resourceType, threshold),
1L
);
QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(resourceType, usage);
queryGroupLevelViews.put(queryGroupId1, mockView);

List<TaskCancellation> cancellableTasksFrom = taskCancellation.getCancellableTasksFrom(queryGroup1);
assertEquals(2, cancellableTasksFrom.size());
assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId());
assertEquals(4321, cancellableTasksFrom.get(1).getTask().getId());
assertEquals(
"[Workload Management] Cancelling Task ID : "
+ cancellableTasksFrom.get(0).getTask().getId()
+ " from QueryGroup ID : queryGroup1"
+ " breached the resource limit of : 10.0 for resource type : cpu",
cancellableTasksFrom.get(0).getReasonString()
);
assertEquals(5, cancellableTasksFrom.get(0).getReasons().get(0).getCancellationScore());
}

public void testGetCancellableTasksFrom_returnsTasksWhenBreachingThreshold() {
ResourceType resourceType = ResourceType.CPU;
long usage = 100_000_000L;
Expand Down Expand Up @@ -216,8 +247,7 @@ public void testCancelTasks_cancelsTasksFromDeletedQueryGroups() {
);

QueryGroupLevelResourceUsageView mockView1 = createResourceUsageViewMock(resourceType, usage);
QueryGroupLevelResourceUsageView mockView2 = mock(QueryGroupLevelResourceUsageView.class);
when(mockView2.getActiveTasks()).thenReturn(List.of(getRandomSearchTask(1000), getRandomSearchTask(1001)));
QueryGroupLevelResourceUsageView mockView2 = createResourceUsageViewMock(resourceType, usage, List.of(1000, 1001));
queryGroupLevelViews.put(queryGroupId1, mockView1);
queryGroupLevelViews.put(queryGroupId2, mockView2);
activeQueryGroups.add(activeQueryGroup);
Expand All @@ -228,7 +258,7 @@ public void testCancelTasks_cancelsTasksFromDeletedQueryGroups() {
queryGroupLevelViews,
activeQueryGroups,
deletedQueryGroups,
() -> false
() -> true
);

List<TaskCancellation> cancellableTasksFrom = taskCancellation.getAllCancellableTasks(QueryGroup.ResiliencyMode.ENFORCED);
Expand All @@ -251,6 +281,62 @@ public void testCancelTasks_cancelsTasksFromDeletedQueryGroups() {
assertTrue(cancellableTasksFromDeletedQueryGroups.get(1).getTask().isCancelled());
}

public void testCancelTasks_does_not_cancelTasksFromDeletedQueryGroups_whenNodeNotInDuress() {
ResourceType resourceType = ResourceType.CPU;
long usage = 150_000_000_000L;
Double threshold = 0.01;

QueryGroup activeQueryGroup = new QueryGroup(
"testQueryGroup",
queryGroupId1,
QueryGroup.ResiliencyMode.ENFORCED,
Map.of(resourceType, threshold),
1L
);

QueryGroup deletedQueryGroup = new QueryGroup(
"testQueryGroup",
queryGroupId2,
QueryGroup.ResiliencyMode.ENFORCED,
Map.of(resourceType, threshold),
1L
);

QueryGroupLevelResourceUsageView mockView1 = createResourceUsageViewMock(resourceType, usage);
QueryGroupLevelResourceUsageView mockView2 = createResourceUsageViewMock(resourceType, usage, List.of(1000, 1001));
queryGroupLevelViews.put(queryGroupId1, mockView1);
queryGroupLevelViews.put(queryGroupId2, mockView2);
activeQueryGroups.add(activeQueryGroup);
deletedQueryGroups.add(deletedQueryGroup);

TestTaskCancellationImpl taskCancellation = new TestTaskCancellationImpl(
new DefaultTaskSelectionStrategy(),
queryGroupLevelViews,
activeQueryGroups,
deletedQueryGroups,
() -> false
);

List<TaskCancellation> cancellableTasksFrom = taskCancellation.getAllCancellableTasks(QueryGroup.ResiliencyMode.ENFORCED);
assertEquals(2, cancellableTasksFrom.size());
assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId());
assertEquals(4321, cancellableTasksFrom.get(1).getTask().getId());

List<TaskCancellation> cancellableTasksFromDeletedQueryGroups = taskCancellation.getTaskCancellationsForDeletedQueryGroup(
deletedQueryGroup
);
assertEquals(2, cancellableTasksFromDeletedQueryGroups.size());
assertEquals(1000, cancellableTasksFromDeletedQueryGroups.get(0).getTask().getId());
assertEquals(1001, cancellableTasksFromDeletedQueryGroups.get(1).getTask().getId());

taskCancellation.cancelTasks();

assertTrue(cancellableTasksFrom.get(0).getTask().isCancelled());
assertTrue(cancellableTasksFrom.get(1).getTask().isCancelled());
assertFalse(cancellableTasksFromDeletedQueryGroups.get(0).getTask().isCancelled());
assertFalse(cancellableTasksFromDeletedQueryGroups.get(1).getTask().isCancelled());
}

public void testCancelTasks_cancelsGivenTasks_WhenNodeInDuress() {
ResourceType resourceType = ResourceType.CPU;
long usage = 150_000_000_000L;
Expand Down Expand Up @@ -384,6 +470,21 @@ private QueryGroupLevelResourceUsageView createResourceUsageViewMock(ResourceTyp
return mockView;
}

private QueryGroupLevelResourceUsageView createResourceUsageViewMock(
ResourceType resourceType,
Long usage,
Collection<Integer> ids
) {
QueryGroupLevelResourceUsageView mockView = mock(QueryGroupLevelResourceUsageView.class);
when(mockView.getResourceUsageData()).thenReturn(Collections.singletonMap(resourceType, usage));
when(mockView.getActiveTasks()).thenReturn(
ids.stream()
.map(this::getRandomSearchTask)
.collect(Collectors.toList())
);
return mockView;
}

private Task getRandomSearchTask(long id) {
return new SearchTask(
id,
Expand Down
Loading

0 comments on commit cbb51bd

Please sign in to comment.