From 3da4f3c6cd2a4eaa2372def731d791174b618209 Mon Sep 17 00:00:00 2001 From: baunsgaard Date: Mon, 7 Aug 2023 12:35:50 +0200 Subject: [PATCH] Split balanced --- scripts/builtin/splitBalanced.dml | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/scripts/builtin/splitBalanced.dml b/scripts/builtin/splitBalanced.dml index bb1d86bce87..3caad9e491a 100644 --- a/scripts/builtin/splitBalanced.dml +++ b/scripts/builtin/splitBalanced.dml @@ -43,9 +43,9 @@ return (Matrix[Double] X_train, Matrix[Double] y_train, Matrix[Double] X_test, Matrix[Double] y_test) { + classes = table(Y, 1) XY = order(target = cbind(Y, X), by = 1, decreasing=FALSE, index.return=FALSE) # get the class count - classes = table(XY[, 1], 1) split = floor(nrow(X) * splitRatio) start_class = 1 train_row_s = 1 @@ -70,13 +70,14 @@ return (Matrix[Double] X_train, Matrix[Double] y_train, Matrix[Double] X_test, { end_class = end_class + as.scalar(classes[i]) class_t = XY[start_class:end_class, ] + ratio = as.scalar(classes_ratio_train[i]) - train_row_e = train_row_e + as.scalar(classes_ratio_train[i]) + train_row_e = train_row_e + ratio test_row_e = test_row_e + as.scalar(classes_ratio_test[i]) - outTrain[train_row_s:train_row_e, ] = class_t[1:as.scalar(classes_ratio_train[i]), ] + outTrain[train_row_s:train_row_e, ] = class_t[1:ratio, ] - outTest[test_row_s:test_row_e, ] = class_t[as.scalar(classes_ratio_train[i])+1:nrow(class_t), ] + outTest[test_row_s:test_row_e, ] = class_t[ratio+1:nrow(class_t), ] train_row_s = train_row_e + 1 test_row_s = test_row_e + 1