Skip to content

Commit

Permalink
fixed centroid generation bugs. all cvtarchive tests passed.
Browse files Browse the repository at this point in the history
  • Loading branch information
HenryChen4 committed Dec 11, 2023
1 parent 5e7d82d commit 2cd5599
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions ribs/archives/_cvt_archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,17 +214,27 @@ def __init__(self,
elif centroid_method == "sobol":
# Generate self._cells number of centroids as a Sobol sequence.
sampler = Sobol(d=self._measure_dim, scramble=False)
num_points = np.log2(self._cells).astype(int)
self._centroids = sampler.random_base2(num_points)
sobol_nums = sampler.random(n=self._cells)
lower = self._lower_bounds
upper = self._upper_bounds
scaled_sobol_nums = lower + sobol_nums * (upper - lower)
self._centroids = scaled_sobol_nums
elif centroid_method == "scrambled_sobol":
# Generates centroids as a scrambled Sobol sequence.
sampler = Sobol(d=self._measure_dim, scramble=True)
num_points = np.log2(self._cells).astype(int)
self._centroids = sampler.random_base2(num_points)
sobol_nums = sampler.random(n=self._cells)
lower = self._lower_bounds
upper = self._upper_bounds
scaled_sobol_nums = lower + sobol_nums * (upper - lower)
self._centroids = scaled_sobol_nums
elif centroid_method == "halton":
# Generates centroids using a Halton sequence.
sampler = Halton(d=self._measure_dim)
self._centroids = sampler.random(n=self._cells)
halton_nums = sampler.random(n=self._cells)
lower = self._lower_bounds
upper = self._upper_bounds
scaled_halton_nums = lower + halton_nums * (upper - lower)
self._centroids = scaled_halton_nums
else:
# Validate shape of `custom_centroids` when they are provided.
custom_centroids = np.asarray(custom_centroids, dtype=self.dtype)
Expand Down

0 comments on commit 2cd5599

Please sign in to comment.