-
Notifications
You must be signed in to change notification settings - Fork 467
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SYSTEMDS-3153] Missing value imputation via KNN-based methods
LDE project SoSe'23. Closes #1879.
- Loading branch information
Showing
4 changed files
with
308 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
#------------------------------------------------------------- | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# | ||
#------------------------------------------------------------- | ||
|
||
|
||
# Imputes missing values, indicated by NaNs, using KNN-based methods | ||
# (k-nearest neighbors by euclidean distance). In order to avoid NaNs in | ||
# distance computation and meaningful nearest neighbor search, we initialize | ||
# the missing values by column means. Currently, only the column with the most | ||
# missing values is actually imputed. | ||
# | ||
# ------------------------------------------------------------------------------ | ||
# INPUT: | ||
# ------------------------------------------------------------------------------ | ||
# X Matrix with missing values, which are represented as NaNs | ||
# method Method used for imputing missing values with different performance | ||
# and accuracy tradeoffs: | ||
# 'dist' (default): Compute all-pairs distances and impute the | ||
# missing values by closest. O(N^2 * #features) | ||
# 'dist_missing': Compute distances between data and records with | ||
# missing values. O(N*M * #features), assuming | ||
# that the number of records with MV is M<<N. | ||
# 'dist_sample': Compute distances between sample of data and | ||
# records with missing values. O(S*M * #features) | ||
# with M<<N and S<<N, but suboptimal imputation. | ||
# seed Root seed value for random/sample calls for deterministic behavior | ||
# -1 for true randomization | ||
# ------------------------------------------------------------------------------ | ||
# | ||
# OUTPUT: | ||
# ------------------------------------------------------------------------------ | ||
# result Imputed dataset | ||
# ------------------------------------------------------------------------------ | ||
|
||
m_imputeByKNN = function(Matrix[Double] X, String method="dist", Int seed=-1) | ||
return(Matrix[Double] result) | ||
{ | ||
#TODO fix seed handling (only root seed) | ||
#TODO fix imputation for all columns with missing values | ||
|
||
#KNN-Imputation Script | ||
|
||
#Create a mask for placeholder and to check for missing values | ||
masked = is.nan(X) | ||
|
||
#Find the column containing multiple missing values | ||
missing_col = rowIndexMax(colSums(is.nan(X))) | ||
|
||
#Impute NaN value with temporary mean value of the column | ||
filled_matrix = imputeByMean(X, matrix(0, cols = ncol(X), rows = 1)) | ||
|
||
if(method == "dist") { | ||
#Calculate the distance using dist method after imputation with mean | ||
distance_matrix = dist(filled_matrix) | ||
|
||
#Change 0 value so rowIndexMin will ignore that diagonal value | ||
distance_matrix = replace(target = distance_matrix, pattern = 0, replacement = 999) | ||
|
||
#Get the minimum distance row-wise computation | ||
minimum_index = rowIndexMin(distance_matrix) | ||
|
||
#Position of missing values in per row in which column | ||
position = rowSums(is.nan(X)) | ||
position = position * minimum_index | ||
|
||
#Filter the 0 out | ||
I = (rowSums(is.nan(X))!=0) | ||
missing = removeEmpty(target=position, margin="rows", select=I) | ||
|
||
#Convert the value indices into 0/1 matrix to find location | ||
indices = table(missing, seq(1,nrow(filled_matrix)),odim1=nrow(filled_matrix),odim2=nrow(missing)) | ||
|
||
#Replace the index with value | ||
imputedValue = t(indices) %*% filled_matrix[,as.scalar(missing_col)] | ||
|
||
#Get the index location of the missing value | ||
pos = rowSums(is.nan(X)) | ||
missing_indices = seq(1, nrow(pos)) * pos | ||
|
||
#Put the replacement value in the missing indices | ||
I2 = removeEmpty(target=missing_indices, margin="rows") | ||
R = table(I2,1,imputedValue,odim1 = nrow(X), odim2=1) | ||
|
||
#Replace the masked column with to be imputed Value | ||
masked[,as.scalar(missing_col)] = masked[,as.scalar(missing_col)] * R | ||
} | ||
else if(method == "dist_missing") { | ||
#assuming small missing values | ||
#Split the matrix into containing NaN values (missing records) and not containing NaN values (M2 records) | ||
I = (rowSums(is.nan(X))!=0) | ||
missing = removeEmpty(target=filled_matrix, margin="rows", select=I) | ||
|
||
Y = (rowSums(is.nan(X))==0) | ||
M2 = removeEmpty(target=filled_matrix, margin = "rows", select = Y) | ||
|
||
#Calculate the euclidean distance between fully records and missing records, and then find the min value row wise | ||
dotM2 = rowSums(M2 * M2) %*% matrix(1, rows = 1, cols = nrow(missing)) | ||
dotMissing = t(rowSums(missing * missing) %*% matrix(1, rows = 1, cols = nrow(M2))) | ||
D = sqrt(dotM2 + dotMissing - 2 * (M2 %*% t(missing))) | ||
minD = rowIndexMin(t(D)) | ||
|
||
#Convert the value indices into 0/1 matrix to find location | ||
indices = table(minD, seq(1,nrow(M2)),odim1=nrow(M2),odim2=nrow(minD)) | ||
|
||
#Replace the value | ||
imputedValue = t(indices) %*% M2[,as.scalar(missing_col)] | ||
|
||
#Get the index location of the missing value | ||
pos = rowSums(is.nan(X)) | ||
missing_indices = seq(1, nrow(pos)) * pos | ||
|
||
#Put the replacement value in the missing indices | ||
I2 = removeEmpty(target=missing_indices, margin="rows") | ||
R = table(I2,1,imputedValue,odim1 = nrow(X), odim2=1) | ||
|
||
#Update the masked value | ||
masked[,as.scalar(missing_col)] = masked[,as.scalar(missing_col)] * R | ||
} | ||
else if(method == "dist_sample"){ | ||
#assuming large missing values | ||
#Split the matrix into containing NaN values (missing records) and not containing NaN values (M2 records) | ||
I = (rowSums(is.nan(X))!=0) | ||
missing = removeEmpty(target=filled_matrix, margin="rows", select=I) | ||
|
||
Y = (rowSums(is.nan(X))==0) | ||
M3 = removeEmpty(target=filled_matrix, margin = "rows", select = Y) | ||
|
||
#Create a random subset | ||
random_matrix = ceiling(rand(rows = nrow(M3), cols = 1, min = 0, max = 1, sparsity = 0.5, seed = seed)) | ||
|
||
#ensure that random_matrix has at least 1 value | ||
if(as.scalar(colSums(random_matrix)) < 1) | ||
random_matrix = matrix(1, rows = nrow(M3), cols = 1) | ||
|
||
subset = M3 * random_matrix | ||
subset = removeEmpty(target=subset, margin = "rows", select = random_matrix) | ||
|
||
#Calculate the euclidean distance between fully records and missing records, and then find the min value row wise | ||
dotSubset = rowSums(subset * subset) %*% matrix(1, rows = 1, cols = nrow(missing)) | ||
dotMissing = t(rowSums(missing * missing) %*% matrix(1, rows = 1, cols = nrow(subset))) | ||
D = sqrt(dotSubset + dotMissing - 2 * (subset %*% t(missing))) | ||
minD = rowIndexMin(t(D)) | ||
|
||
#Convert the value indices into 0/1 matrix to find location | ||
indices = table(minD, seq(1,nrow(subset)),odim1=nrow(subset),odim2=nrow(minD)) | ||
|
||
#Replace the value | ||
imputedValue = t(indices) %*% subset[,as.scalar(missing_col)] | ||
|
||
#Get the index location of the missing value | ||
pos = rowSums(is.nan(X)) | ||
missing_indices = seq(1, nrow(pos)) * pos | ||
|
||
#Put the replacement value in the missing indices | ||
I2 = removeEmpty(target=missing_indices, margin="rows") | ||
R = table(I2,1,imputedValue,odim1 = nrow(X), odim2=1) | ||
|
||
#Update the masked value | ||
masked[,as.scalar(missing_col)] = masked[,as.scalar(missing_col)] * R | ||
} | ||
else { | ||
print("Method is unknown or not yet implemented") | ||
} | ||
|
||
#Impute the value | ||
result = replace(target = X, pattern = NaN, replacement = 0) | ||
result = result + masked | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
77 changes: 77 additions & 0 deletions
77
src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinImputeKNNTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
package org.apache.sysds.test.functions.builtin.part1; | ||
|
||
import org.apache.sysds.common.Types; | ||
import org.apache.sysds.common.Types.ExecMode; | ||
import org.apache.sysds.common.Types.ExecType; | ||
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; | ||
import org.apache.sysds.test.AutomatedTestBase; | ||
import org.apache.sysds.test.TestConfiguration; | ||
import org.apache.sysds.test.TestUtils; | ||
import org.junit.Assert; | ||
import org.junit.Test; | ||
|
||
import java.io.IOException; | ||
|
||
public class BuiltinImputeKNNTest extends AutomatedTestBase { | ||
|
||
private final static String TEST_NAME = "imputeByKNN"; | ||
private final static String TEST_DIR = "functions/builtin/"; | ||
private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinImputeKNNTest.class.getSimpleName() + "/"; | ||
|
||
private double eps = 10; | ||
@Override | ||
public void setUp() { | ||
TestUtils.clearAssertionInformation(); | ||
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B","B2"})); | ||
} | ||
|
||
@Test | ||
public void testDefaultCP()throws IOException{ | ||
runImputeKNN(true, Types.ExecType.CP); | ||
} | ||
|
||
@Test | ||
public void testDefaultSpark()throws IOException{ | ||
runImputeKNN(true, Types.ExecType.SPARK); | ||
} | ||
|
||
private void runImputeKNN(boolean defaultProb, ExecType instType) throws IOException { | ||
ExecMode platform_old = setExecMode(instType); | ||
try { | ||
loadTestConfiguration(getTestConfiguration(TEST_NAME)); | ||
String HOME = SCRIPT_DIR + TEST_DIR; | ||
fullDMLScriptName = HOME + TEST_NAME + ".dml"; | ||
programArgs = new String[] {"-args", DATASET_DIR+"Salaries.csv", | ||
"dist", "dist_missing", output("B"), output("B2")}; | ||
|
||
runTest(true, false, null, -1); | ||
|
||
//Compare matrices, check if the sum of the imputed value is roughly the same | ||
double sum1 = readDMLMatrixFromOutputDir("B").get(new CellIndex(1,1)); | ||
double sum2 = readDMLMatrixFromOutputDir("B2").get(new CellIndex(1,1)); | ||
Assert.assertEquals(sum1, sum2, eps); | ||
} | ||
finally { | ||
rtplatform = platform_old; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
#------------------------------------------------------------- | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# | ||
#------------------------------------------------------------- | ||
|
||
# Prepare the data | ||
X = read($1, data_type="frame", format="csv", header=TRUE, naStrings= ["20"]); | ||
X = cbind(as.matrix(X[,4:5]), as.matrix(X[,7])) | ||
remove_col = is.nan(X) | ||
|
||
data = removeEmpty(target = X, margin = "rows", select = (remove_col[,1] != 1)) | ||
mask = is.nan(data) | ||
|
||
#Perform the KNN imputation | ||
result = imputeByKNN(X = data, method = $2) | ||
result2 = imputeByKNN(X = data, method = $3) | ||
|
||
#Get the imputed value | ||
I = (mask[,2] == 1); | ||
value = removeEmpty(target = result, margin = "rows", select = I) | ||
value2 = removeEmpty(target = result2, margin = "rows", select = I) | ||
|
||
#Get the sum of the imputed value | ||
value = colSums(value[,2]) | ||
value2 = colSums(value2[,2]) | ||
|
||
write(value, $4) | ||
write(value2, $5) |