Skip to content

Commit

Permalink
[MINOR] Frame tests improvement 2
Browse files Browse the repository at this point in the history
Add tests to reach 100% test coverage for Frame/data/lib

Closes #2120
  • Loading branch information
Baunsgaard committed Sep 27, 2024
1 parent d80e3a6 commit 3b4f6cd
Show file tree
Hide file tree
Showing 9 changed files with 474 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,12 @@
import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata;

public class FrameLibAppend {

protected static final Log LOG = LogFactory.getLog(FrameLibAppend.class.getName());

private FrameLibAppend(){
// Private constructor: prevents instantiation from outside this class.
// The methods of this class visible here (append, appendCbind, appendRbind) are all static.
}

/**
* Appends the given argument FrameBlock 'that' to this FrameBlock by creating a deep copy to prevent side effects.
* For cbind, the frames are appended column-wise (same number of rows), while for rbind the frames are appended
Expand All @@ -50,7 +54,7 @@ public static FrameBlock append(FrameBlock a, FrameBlock b, boolean cbind) {
return ret;
}

public static FrameBlock appendCbind(FrameBlock a, FrameBlock b) {
private static FrameBlock appendCbind(FrameBlock a, FrameBlock b) {
final int nRow = a.getNumRows();
final int nRowB = b.getNumRows();

Expand All @@ -73,7 +77,7 @@ else if(b.getNumColumns() == 0)
return new FrameBlock(_schema, _colnames, _colmeta, _coldata);
}

public static FrameBlock appendRbind(FrameBlock a, FrameBlock b) {
private static FrameBlock appendRbind(FrameBlock a, FrameBlock b) {
final int nCol = a.getNumColumns();
final int nColB = b.getNumColumns();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

Expand Down Expand Up @@ -67,11 +66,16 @@ public static FrameBlock detectSchema(FrameBlock in, double sampleFraction, int
}

private FrameBlock apply() {
final int cols = in.getNumColumns();
final FrameBlock fb = new FrameBlock(UtilFunctions.nCopies(cols, ValueType.STRING));
String[] schemaInfo = (k == 1) ? singleThreadApply() : parallelApply();
fb.appendRow(schemaInfo);
return fb;
try{
final int cols = in.getNumColumns();
final FrameBlock fb = new FrameBlock(UtilFunctions.nCopies(cols, ValueType.STRING));
String[] schemaInfo = (k == 1) ? singleThreadApply() : parallelApply();
fb.appendRow(schemaInfo);
return fb;
}
catch(Exception e){
throw new DMLRuntimeException("Failed to detect schema", e);
}
}

private String[] singleThreadApply() {
Expand All @@ -84,7 +88,7 @@ private String[] singleThreadApply() {
return schemaInfo;
}

private String[] parallelApply() {
private String[] parallelApply() throws Exception {
final ExecutorService pool = CommonThreadPool.get(k);
try {
final int cols = in.getNumColumns();
Expand All @@ -99,9 +103,6 @@ private String[] parallelApply() {

return schemaInfo;
}
catch(ExecutionException | InterruptedException e) {
throw new DMLRuntimeException("Exception interrupted or exception thrown in detectSchema", e);
}
finally{
pool.shutdown();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,33 +290,30 @@ public static ValueType isType(double val, ValueType min) {
}

public static FrameBlock mergeSchema(FrameBlock temp1, FrameBlock temp2) {
String[] rowTemp1 = IteratorFactory.getStringRowIterator(temp1).next();
String[] rowTemp2 = IteratorFactory.getStringRowIterator(temp2).next();
final int nCol = temp1.getNumColumns();

if(rowTemp1.length != rowTemp2.length)
throw new DMLRuntimeException("Schema dimension " + "mismatch: " + rowTemp1.length + " vs " + rowTemp2.length);
if(nCol != temp2.getNumColumns())
throw new DMLRuntimeException("Schema dimension mismatch: " + nCol + " vs " + temp2.getNumColumns());

for(int i = 0; i < rowTemp1.length; i++) {
// hack reuse input temp1 schema, only valid if temp1 never change schema.
// However, this is typically valid.
FrameBlock mergedFrame = new FrameBlock(temp1.getSchema());
mergedFrame.ensureAllocatedColumns(1);
for(int i = 0; i < nCol; i++) {
String s1 = (String) temp1.get(0, i);
String s2 = (String) temp2.get(0, i);
// modify schema1 if necessary (different schema2)
if(!rowTemp1[i].equals(rowTemp2[i])) {
if(rowTemp1[i].equals("STRING") || rowTemp2[i].equals("STRING"))
rowTemp1[i] = "STRING";
else if(rowTemp1[i].equals("FP64") || rowTemp2[i].equals("FP64"))
rowTemp1[i] = "FP64";
else if(rowTemp1[i].equals("FP32") &&
new ArrayList<>(Arrays.asList("INT64", "INT32", "CHARACTER")).contains(rowTemp2[i]))
rowTemp1[i] = "FP32";
else if(rowTemp1[i].equals("INT64") &&
new ArrayList<>(Arrays.asList("INT32", "CHARACTER")).contains(rowTemp2[i]))
rowTemp1[i] = "INT64";
else if(rowTemp1[i].equals("INT32") || rowTemp2[i].equals("CHARACTER"))
rowTemp1[i] = "INT32";
if(!s1.equals(s2)) {
ValueType v1 = ValueType.valueOf(s1);
ValueType v2 = ValueType.valueOf(s2);
ValueType vc = ValueType.getHighestCommonTypeSafe(v1, v2);
mergedFrame.set(0, i, vc.toString());
}
else{
mergedFrame.set(0, i, s1);
}
}

// create output block one row representing the schema as strings
FrameBlock mergedFrame = new FrameBlock(UtilFunctions.nCopies(temp1.getNumColumns(), ValueType.STRING));
mergedFrame.appendRow(rowTemp1);
return mergedFrame;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
Expand All @@ -40,28 +41,56 @@ public interface MatrixBlockFromFrame {
* Converts a frame block with arbitrary schema into a matrix block. Since matrix block only supports value type
* double, we do a best effort conversion of non-double types which might result in errors for non-numerical data.
*
* @param frame frame block
* @param k parallelization degree
* @return matrix block
* @param frame Frame block to convert
* @param k The parallelization degree
* @return MatrixBlock
*/
public static MatrixBlock convertToMatrixBlock(FrameBlock frame, int k) {
final int m = frame.getNumRows();
final int n = frame.getNumColumns();
final MatrixBlock mb = new MatrixBlock(m, n, false);
mb.allocateDenseBlock();
if(k == -1)
k = InfrastructureAnalyzer.getLocalParallelism();

long nnz = 0;
if(k == 1)
nnz = convert(frame, mb, n, 0, m);
else
nnz = convertParallel(frame, mb, m, n, k);
return convertToMatrixBlock(frame, null, k);
}

mb.setNonZeros(nnz);
/**
* Converts a frame block with arbitrary schema into a matrix block. Since matrix block only supports value type
* double, we do a best effort conversion of non-double types which might result in errors for non-numerical data.
*
* @param frame FrameBlock to convert
* @param ret The returned MatrixBlock
* @param k The parallelization degree
* @return MatrixBlock
*/
public static MatrixBlock convertToMatrixBlock(FrameBlock frame, MatrixBlock ret, int k) {
try {

mb.examSparsity();
return mb;
final int m = frame.getNumRows();
final int n = frame.getNumColumns();
ret = allocateRet(ret, m, n);

if(k == -1)
k = InfrastructureAnalyzer.getLocalParallelism();

long nnz = 0;
if(k == 1)
nnz = convert(frame, ret, n, 0, m);
else
nnz = convertParallel(frame, ret, m, n, k);

ret.setNonZeros(nnz);
ret.examSparsity();
return ret;
}
catch(Exception e) {
throw new DMLRuntimeException("Failed to convert FrameBlock to MatrixBlock", e);
}
}

/**
 * Prepare the output block used by the frame-to-matrix conversion.
 *
 * A new block is created when none is given. A given block is reset when its dimensions differ from the
 * requested ones or when it reports being in sparse format. In all cases the dense block is allocated if the
 * block is not yet allocated.
 *
 * @param ret Candidate output block, may be null
 * @param m   Requested number of rows
 * @param n   Requested number of columns
 * @return The block to write into; the given block itself whenever it was non-null
 */
private static MatrixBlock allocateRet(MatrixBlock ret, final int m, final int n) {
	final MatrixBlock out;
	if(ret == null) {
		out = new MatrixBlock(m, n, false);
	}
	else {
		out = ret;
		final boolean sameDims = out.getNumRows() == m && out.getNumColumns() == n;
		// the sparse-format check is only reached when the dimensions already match
		if(!sameDims || out.isInSparseFormat())
			out.reset(m, n, false);
	}

	if(!out.isAllocated())
		out.allocateDenseBlock();

	return out;
}

private static long convert(FrameBlock frame, MatrixBlock mb, int n, int rl, int ru) {
Expand All @@ -71,27 +100,25 @@ private static long convert(FrameBlock frame, MatrixBlock mb, int n, int rl, int
return convertGeneric(frame, mb, n, rl, ru);
}

private static long convertParallel(FrameBlock frame, MatrixBlock mb, int m, int n, int k){
private static long convertParallel(FrameBlock frame, MatrixBlock mb, int m, int n, int k) throws Exception {
ExecutorService pool = CommonThreadPool.get(k);
try{
try {
List<Future<Long>> tasks = new ArrayList<>();
final int blkz = Math.max(m / k, 1000);

for( int i = 0; i < m; i+= blkz){
final int start = i;
for(int i = 0; i < m; i += blkz) {
final int start = i;
final int end = Math.min(i + blkz, m);
tasks.add(pool.submit(() -> convert(frame, mb, n, start, end)));
}

long nnz = 0;
for( Future<Long> t : tasks)
for(Future<Long> t : tasks)
nnz += t.get();
return nnz;
}
catch(Exception e){
throw new RuntimeException(e);
}
finally{

finally {
pool.shutdown();
}
}
Expand All @@ -104,29 +131,42 @@ private static long convertContiguous(final FrameBlock frame, final MatrixBlock
for(int bj = 0; bj < n; bj += blocksizeIJ) {
int bimin = Math.min(bi + blocksizeIJ, ru);
int bjmin = Math.min(bj + blocksizeIJ, n);
for(int i = bi, aix = bi * n; i < bimin; i++, aix += n)
for(int j = bj; j < bjmin; j++)
lnnz += (c[aix + j] = frame.getDoubleNaN(i, j)) != 0 ? 1 : 0;
lnnz = convertBlockContiguous(frame, n, lnnz, c, bi, bj, bimin, bjmin);
}
}
return lnnz;
}

private static long convertGeneric(final FrameBlock frame, final MatrixBlock mb, final int n, final int rl, final int ru) {
/**
 * Copy the cell range [rl, ru) x [cl, cu) of the frame into a flat array, counting non-zero values.
 *
 * Cell (i, j) is written to index i * n + j of the target array, and every written value that is not equal
 * to zero increments the running count.
 *
 * @param frame Frame to read values from via getDoubleNaN
 * @param n     Row stride of the target array
 * @param lnnz  Non-zero count accumulated so far
 * @param c     Target array
 * @param rl    First row (inclusive)
 * @param cl    First column (inclusive)
 * @param ru    Last row (exclusive)
 * @param cu    Last column (exclusive)
 * @return The incoming count plus the number of non-zero values written by this call
 */
private static long convertBlockContiguous(final FrameBlock frame, final int n, long lnnz, double[] c, int rl,
	int cl, int ru, int cu) {
	long count = lnnz;
	// offset of the first cell of the current row in the flat array
	int rowOffset = rl * n;
	for(int row = rl; row < ru; row++) {
		for(int col = cl; col < cu; col++) {
			final double value = frame.getDoubleNaN(row, col);
			c[rowOffset + col] = value;
			if(value != 0)
				count++;
		}
		rowOffset += n;
	}
	return count;
}

private static long convertGeneric(final FrameBlock frame, final MatrixBlock mb, final int n, final int rl,
final int ru) {
long lnnz = 0;
final DenseBlock c = mb.getDenseBlock();
for(int bi = rl; bi < ru; bi += blocksizeIJ) {
for(int bj = 0; bj < n; bj += blocksizeIJ) {
int bimin = Math.min(bi + blocksizeIJ, ru);
int bjmin = Math.min(bj + blocksizeIJ, n);
for(int i = bi; i < bimin; i++) {
double[] cvals = c.values(i);
int cpos = c.pos(i);
for(int j = bj; j < bjmin; j++)
lnnz += (cvals[cpos + j] = frame.getDoubleNaN(i, j)) != 0 ? 1 : 0;
}
lnnz = convertBlockGeneric(frame, lnnz, c, bi, bj, bimin, bjmin);
}
}
return lnnz;
}

/**
 * Copy the cell range [rl, ru) x [cl, cu) of the frame into a dense block, counting non-zero values.
 *
 * For each row the backing array and start position are looked up through the dense block (values / pos),
 * and cell (i, j) is written at position pos(i) + j of that array.
 *
 * @param frame Frame to read values from via getDoubleNaN
 * @param lnnz  Non-zero count accumulated so far
 * @param c     Target dense block
 * @param rl    First row (inclusive)
 * @param cl    First column (inclusive)
 * @param ru    Last row (exclusive)
 * @param cu    Last column (exclusive)
 * @return The incoming count plus the number of non-zero values written by this call
 */
private static long convertBlockGeneric(final FrameBlock frame, long lnnz, final DenseBlock c, final int rl,
	final int cl, final int ru, final int cu) {
	long count = lnnz;
	for(int row = rl; row < ru; row++) {
		final double[] rowValues = c.values(row);
		final int rowStart = c.pos(row);
		for(int col = cl; col < cu; col++) {
			final double value = frame.getDoubleNaN(row, col);
			rowValues[rowStart + col] = value;
			if(value != 0)
				count++;
		}
	}
	return count;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ private void resetSparse() {
if(sparseBlock == null)
return;
sparseBlock.reset(estimatedNNzsPerRow, clen);
denseBlock = null;
}

private void resetDense(double val) {
Expand All @@ -343,6 +344,7 @@ else if( val != 0 ) {
allocateDenseBlock(false);
denseBlock.set(val);
}
sparseBlock = null;
}

private void resetDense(double val, boolean dedup) {
Expand All @@ -354,6 +356,7 @@ else if( val != 0 ) {
allocateDenseBlock(false, dedup);
denseBlock.set(val);
}
sparseBlock = null;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,26 @@

package org.apache.sysds.test.component.frame;

import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.frame.data.lib.FrameLibAppend;
import org.apache.sysds.runtime.frame.data.lib.FrameLibDetectSchema;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.test.TestUtils;
import org.junit.Test;

public class FrameCustomTest {
protected static final Log LOG = LogFactory.getLog(FrameCustomTest.class.getName());

@Test
public void castToFrame() {
Expand Down Expand Up @@ -61,4 +71,30 @@ public void castToFrame2() {
assertTrue(f.getSchema()[0] == ValueType.FP64);
}


@Test
public void detectSchemaError(){
	// Spy on a random frame so that every column access throws, forcing a failure inside detectSchema.
	final FrameBlock failingFrame = spy(TestUtils.generateRandomFrameBlock(10, 10, 23));
	when(failingFrame.getColumn(anyInt())).thenThrow(new RuntimeException());

	final DMLRuntimeException thrown = assertThrows(DMLRuntimeException.class,
		() -> FrameLibDetectSchema.detectSchema(failingFrame, 3));

	// The failure must be reported with the detect-schema error message.
	assertTrue(thrown.getMessage().contains("Failed to detect schema"));
}



@Test
public void appendUniqueColNames(){
	// Two single-column FP32 frames with distinct column names.
	final FrameBlock left = new FrameBlock(new ValueType[]{ValueType.FP32}, new String[]{"Hi"});
	left.appendRow(new String[]{"0.2"});
	final FrameBlock right = new FrameBlock(new ValueType[]{ValueType.FP32}, new String[]{"There"});
	right.appendRow(new String[]{"0.5"});

	final FrameBlock combined = FrameLibAppend.append(left, right, true);

	// Column names of both inputs must be kept, in input order.
	final String[] expectedNames = {"Hi", "There"};
	for(int col = 0; col < expectedNames.length; col++)
		assertTrue(combined.getColumnName(col).equals(expectedNames[col]));
}
}
Loading

0 comments on commit 3b4f6cd

Please sign in to comment.