From 2cd5599ad496d21a63651ba74441bdb5d9c833da Mon Sep 17 00:00:00 2001 From: Henry Chen Date: Mon, 11 Dec 2023 13:33:39 -0800 Subject: [PATCH] fixed centroid generation bugs. all cvtarchive tests passed. --- ribs/archives/_cvt_archive.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/ribs/archives/_cvt_archive.py b/ribs/archives/_cvt_archive.py index c98987e25..a200c7293 100644 --- a/ribs/archives/_cvt_archive.py +++ b/ribs/archives/_cvt_archive.py @@ -214,17 +214,27 @@ def __init__(self, elif centroid_method == "sobol": # Generate self._cells number of centroids as a Sobol sequence. sampler = Sobol(d=self._measure_dim, scramble=False) - num_points = np.log2(self._cells).astype(int) - self._centroids = sampler.random_base2(num_points) + sobol_nums = sampler.random(n=self._cells) + lower = self._lower_bounds + upper = self._upper_bounds + scaled_sobol_nums = lower + sobol_nums * (upper - lower) + self._centroids = scaled_sobol_nums elif centroid_method == "scrambled_sobol": # Generates centroids as a scrambled Sobol sequence. sampler = Sobol(d=self._measure_dim, scramble=True) - num_points = np.log2(self._cells).astype(int) - self._centroids = sampler.random_base2(num_points) + sobol_nums = sampler.random(n=self._cells) + lower = self._lower_bounds + upper = self._upper_bounds + scaled_sobol_nums = lower + sobol_nums * (upper - lower) + self._centroids = scaled_sobol_nums elif centroid_method == "halton": # Generates centroids using a Halton sequence. sampler = Halton(d=self._measure_dim) - self._centroids = sampler.random(n=self._cells) + halton_nums = sampler.random(n=self._cells) + lower = self._lower_bounds + upper = self._upper_bounds + scaled_halton_nums = lower + halton_nums * (upper - lower) + self._centroids = scaled_halton_nums else: # Validate shape of `custom_centroids` when they are provided. custom_centroids = np.asarray(custom_centroids, dtype=self.dtype)