Skip to content

Commit

Permalink
Split balanced
Browse files Browse the repository at this point in the history
  • Loading branch information
Baunsgaard committed Aug 7, 2023
1 parent 25c9c8f commit 3da4f3c
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions scripts/builtin/splitBalanced.dml
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ return (Matrix[Double] X_train, Matrix[Double] y_train, Matrix[Double] X_test,
Matrix[Double] y_test)
{

classes = table(Y, 1)
XY = order(target = cbind(Y, X), by = 1, decreasing=FALSE, index.return=FALSE)
# get the class count
classes = table(XY[, 1], 1)
split = floor(nrow(X) * splitRatio)
start_class = 1
train_row_s = 1
Expand All @@ -70,13 +70,14 @@ return (Matrix[Double] X_train, Matrix[Double] y_train, Matrix[Double] X_test,
{
end_class = end_class + as.scalar(classes[i])
class_t = XY[start_class:end_class, ]
ratio = as.scalar(classes_ratio_train[i])

train_row_e = train_row_e + as.scalar(classes_ratio_train[i])
train_row_e = train_row_e + ratio
test_row_e = test_row_e + as.scalar(classes_ratio_test[i])

outTrain[train_row_s:train_row_e, ] = class_t[1:as.scalar(classes_ratio_train[i]), ]
outTrain[train_row_s:train_row_e, ] = class_t[1:ratio, ]

outTest[test_row_s:test_row_e, ] = class_t[as.scalar(classes_ratio_train[i])+1:nrow(class_t), ]
outTest[test_row_s:test_row_e, ] = class_t[ratio+1:nrow(class_t), ]

train_row_s = train_row_e + 1
test_row_s = test_row_e + 1
Expand Down

0 comments on commit 3da4f3c

Please sign in to comment.