Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prevent feature corruption in LRU cache #4650

Merged
merged 10 commits into from
Jan 23, 2024
11 changes: 5 additions & 6 deletions vowpalwabbit/core/src/reductions/eigen_memory_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -741,13 +741,14 @@ void node_split(emt_tree& b, emt_node& cn)
cn.examples.clear();
}

// NOTE(review): this listing is a diff view whose +/- markers were stripped. The two signature
// lines below look like the pre-change and post-change forms of the same function (the second
// adds the `emt_tree& b` parameter that the tree_bound call needs) — confirm against the commit.
void node_insert(emt_node& cn, std::unique_ptr<emt_example> ex)
void node_insert(emt_tree& b, emt_node& cn, std::unique_ptr<emt_example> ex)
{
// Nothing is inserted if the node already holds an example whose `full` compares equal to the
// incoming one; `ex` is then released when this function returns.
for (auto& cn_ex : cn.examples)
{
if (cn_ex->full == ex->full) { return; }
}
// Ownership moves into the node's example list.
cn.examples.push_back(std::move(ex));
// tree_bound receives the raw pointer of the element that was just stored in the node, i.e. an
// object owned by cn.examples rather than by the caller.
tree_bound(b, cn.examples.back().get());
}

emt_example* node_pick(emt_tree& b, learner& base, emt_node& cn, const emt_example& ex)
Expand Down Expand Up @@ -779,16 +780,15 @@ void node_predict(emt_tree& b, learner& base, emt_node& cn, emt_example& ex, VW:
auto* closest_ex = node_pick(b, base, cn, ex);
ec.pred.multiclass = (closest_ex != nullptr) ? closest_ex->label : 0;
ec.loss = (ec.l.multi.label != ec.pred.multiclass) ? ec.weight : 0;
if (closest_ex != nullptr) { tree_bound(b, closest_ex); }
}

// Prediction entry point: wraps `ec` in a temporary emt_example, routes it to a node with
// tree_route and lets node_predict fill in the prediction and loss on `ec`.
void emt_predict(emt_tree& b, learner& base, VW::example& ec)
{
b.all->feature_tweaks_config.ignore_some_linear = false;
// `ex` is a stack-local object that goes out of scope when this function returns.
emt_example ex(*b.all, &ec);

// tree_route returns a pointer, which is dereferenced here without a null check.
emt_node& cn = *tree_route(b, ex);
node_predict(b, base, cn, ex, ec);
// NOTE(review): this passes the address of the local `ex` to tree_bound. The listing is a diff
// with its +/- markers stripped, and this is presumably one of the removed lines — confirm.
tree_bound(b, &ex);
}

void emt_learn(emt_tree& b, learner& base, VW::example& ec)
Expand All @@ -797,10 +797,9 @@ void emt_learn(emt_tree& b, learner& base, VW::example& ec)
auto ex = VW::make_unique<emt_example>(*b.all, &ec);

emt_node& cn = *tree_route(b, *ex);
scorer_learn(b, base, cn, *ex, ec.weight);
node_predict(b, base, cn, *ex, ec); // vw learners predict and emt_learn
tree_bound(b, ex.get());
node_insert(cn, std::move(ex));
scorer_learn(b, base, cn, *ex, ec.weight);
node_insert(b, cn, std::move(ex));
node_split(b, cn);
}

Expand Down
41 changes: 40 additions & 1 deletion vowpalwabbit/core/tests/eigen_memory_tree_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ TEST(EigenMemoryTree, ExactMatchWithRouterTest)
}
}

TEST(EigenMemoryTree, Bounding)
TEST(EigenMemoryTree, BoundingDrop)
{
auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "5"));
auto* tree = get_emt_tree(*vw);
Expand All @@ -148,6 +148,45 @@ TEST(EigenMemoryTree, Bounding)
EXPECT_EQ(tree->root->router_weights.size(), 0);
}

// A predict-only call on a fresh tree must leave the bounder's list empty.
TEST(EigenMemoryTree, BoundingPredict)
{
  auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "3"));
  auto* emt = get_emt_tree(*vw);

  // Run a single example through predict (never learn), then hand it back to the workspace.
  auto* example = VW::read_example(*vw, "1 | 1");
  vw->predict(*example);
  vw->finish_example(*example);

  EXPECT_EQ(emt->bounder->list.size(), 0);
}

// Tracks which stored example heads the bounder's list after three learn calls and after each of
// two subsequent predict calls.
TEST(EigenMemoryTree, BoundingRecency)
{
  auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "3"));
  auto* emt = get_emt_tree(*vw);

  // First component of entry 0 in `base` for the example currently at the front of the list.
  const auto head_feature = [emt]() { return (*emt->bounder->list.begin())->base[0].first; };

  // Learn the examples "0 | 0", "1 | 1" and "2 | 2", in that order.
  for (int k = 0; k < 3; k++)
  {
    const std::string line = std::to_string(k) + " | " + std::to_string(k);
    auto* learned = VW::read_example(*vw, line);
    vw->learn(*learned);
    vw->finish_example(*learned);
  }

  // Directly after learning, the last learned example is at the front.
  EXPECT_EQ(head_feature(), 2);

  // A prediction with feature 1 is expected to bring the stored "1" example to the front.
  auto* first_query = VW::read_example(*vw, "1 | 1");
  vw->predict(*first_query);
  vw->finish_example(*first_query);

  EXPECT_EQ(head_feature(), 1);

  // A prediction with feature 0 is expected to bring the stored "0" example to the front.
  auto* second_query = VW::read_example(*vw, "1 | 0");
  vw->predict(*second_query);
  vw->finish_example(*second_query);

  EXPECT_EQ(head_feature(), 0);
}

TEST(EigenMemoryTree, Split)
{
auto args = vwtest::make_args("--quiet", "--emt", "--emt_tree", "10", "--emt_leaf", "3");
Expand Down
Loading