diff --git a/src/main/java/com/fenbi/ytklearn/predictor/GBHMLROnlinePredictor.java b/src/main/java/com/fenbi/ytklearn/predictor/GBHMLROnlinePredictor.java index 152bec2..a20b52a 100644 --- a/src/main/java/com/fenbi/ytklearn/predictor/GBHMLROnlinePredictor.java +++ b/src/main/java/com/fenbi/ytklearn/predictor/GBHMLROnlinePredictor.java @@ -143,7 +143,12 @@ public double score(Map features, Object other) { int vidx = ((j + 1) << 1) - 1; mu[idx + j] = mu[idx + vidx] + mu[idx + vidx + 1]; } - fx += learningRate * mu[idx]; + if (tree < treeNum - 1) { + fx += learningRate * mu[idx]; + } else { + fx += mu[idx]; + } + idx += stride; idxg += K; } diff --git a/src/main/java/com/fenbi/ytklearn/predictor/GBHSDTOnlinePredictor.java b/src/main/java/com/fenbi/ytklearn/predictor/GBHSDTOnlinePredictor.java index fe0e1cc..62143c6 100644 --- a/src/main/java/com/fenbi/ytklearn/predictor/GBHSDTOnlinePredictor.java +++ b/src/main/java/com/fenbi/ytklearn/predictor/GBHSDTOnlinePredictor.java @@ -148,7 +148,12 @@ public double score(Map features, Object other) { int vidx = ((j + 1) << 1) - 1; mu[idxm + j] = mu[idxm + vidx] + mu[idxm + vidx + 1]; } - fx += learningRate * mu[idxm]; + + if (tree < treeNum - 1) { + fx += learningRate * mu[idxm]; + } else { + fx += mu[idxm]; + } idx += K - 1; idxg += K; diff --git a/src/main/java/com/fenbi/ytklearn/predictor/GBMLROnlinePredictor.java b/src/main/java/com/fenbi/ytklearn/predictor/GBMLROnlinePredictor.java index 7b27d7b..60efd15 100644 --- a/src/main/java/com/fenbi/ytklearn/predictor/GBMLROnlinePredictor.java +++ b/src/main/java/com/fenbi/ytklearn/predictor/GBMLROnlinePredictor.java @@ -267,7 +267,12 @@ public double score(Map features, Object other) { lfx += gk_1 * wx[idx + stride - 1]; gating[tree * K + K - 1] = gk_1; - fx += learningRate * lfx; + if (tree < treeNum - 1) { + fx += learningRate * lfx; + } else { + fx += lfx; + } + idx += stride; } diff --git a/src/main/java/com/fenbi/ytklearn/predictor/GBSDTOnlinePredictor.java b/src/main/java/com/fenbi/ytklearn/predictor/GBSDTOnlinePredictor.java index 57a0d7f..d60556b 100644 --- a/src/main/java/com/fenbi/ytklearn/predictor/GBSDTOnlinePredictor.java +++ b/src/main/java/com/fenbi/ytklearn/predictor/GBSDTOnlinePredictor.java @@ -244,7 +244,12 @@ public double score(Map features, Object other) { lfx += gk_1 * leaf[tree][vstart - 1]; gating[tree * K + K - 1] = gk_1; - fx += learningRate * lfx; + if (tree < treeNum - 1) { + fx += learningRate * lfx; + } else { + fx += lfx; + } + idx += stride; }