Skip to content

Commit

Permalink
Fix ALEBO statedict extraction for new gpytorch release
Browse files Browse the repository at this point in the history
Summary:
Fixes the test failures on the new release by correctly parsing the non-batched constraint values, and fixing the unit tests that didn't expect constraints in the state dict.

Depends on D24579330

Reviewed By: Balandat

Differential Revision: D24683186

fbshipit-source-id: c4e2e6b7effd5c90759666493fcde2707bae8d3f
  • Loading branch information
bletham authored and facebook-github-bot committed Nov 2, 2020
1 parent c741fef commit 723bb7f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
8 changes: 6 additions & 2 deletions ax/models/tests/test_alebo.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,15 @@ def testALEBOGP(self):
# Test extract_map_statedict
map_sds = extract_map_statedict(m_b=m_b, num_outputs=1)
self.assertEqual(len(map_sds), 1)
self.assertEqual(len(map_sds[0]), 3)
self.assertEqual(len(map_sds[0]), 5)
self.assertEqual(
set(map_sds[0]),
{
"covar_module.base_kernel.Uvec",
"covar_module.raw_outputscale",
"mean_module.constant",
"covar_module.raw_outputscale_constraint.lower_bound",
"covar_module.raw_outputscale_constraint.upper_bound",
},
)
self.assertEqual(
Expand All @@ -138,13 +140,15 @@ def testALEBOGP(self):
map_sds = extract_map_statedict(m_b=ml, num_outputs=2)
self.assertEqual(len(map_sds), 2)
for i in range(2):
self.assertEqual(len(map_sds[i]), 3)
self.assertEqual(len(map_sds[i]), 5)
self.assertEqual(
set(map_sds[i]),
{
"covar_module.base_kernel.Uvec",
"covar_module.raw_outputscale",
"mean_module.constant",
"covar_module.raw_outputscale_constraint.lower_bound",
"covar_module.raw_outputscale_constraint.upper_bound",
},
)
self.assertEqual(
Expand Down
4 changes: 3 additions & 1 deletion ax/models/torch/alebo.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,9 @@ def extract_map_statedict(
else:
model_idx = 0
param_name = k
map_sds[model_idx][param_name] = torch.select(v, 0, 0)
if len(v.shape) > 1:
v = torch.select(v, 0, 0)
map_sds[model_idx][param_name] = v
return map_sds


Expand Down

0 comments on commit 723bb7f

Please sign in to comment.