Skip to content

Commit

Permalink
Switch as_pandas test to test data
Browse files Browse the repository at this point in the history
  • Loading branch information
btjanaka committed Nov 10, 2023
1 parent 78061ac commit 4f782db
Showing 1 changed file with 17 additions and 27 deletions.
44 changes: 17 additions & 27 deletions tests/archives/archive_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,36 +395,27 @@ def test_sample_elites_fails_when_empty(data):

@pytest.mark.parametrize("name", ARCHIVE_NAMES)
@pytest.mark.parametrize("with_elite", [True, False], ids=["nonempty", "empty"])
@pytest.mark.parametrize("include_solutions", [True, False],
ids=["solutions", "no_solutions"])
@pytest.mark.parametrize("include_metadata", [True, False],
ids=["metadata", "no_metadata"])
@pytest.mark.parametrize("dtype", [np.float64, np.float32],
ids=["float64", "float32"])
def test_as_pandas(name, with_elite, include_solutions, include_metadata,
dtype):
def test_pandas_data(name, with_elite, dtype):
data = get_archive_data(name, dtype)

# Set up expected columns and data types.
measure_cols = [f"measures_{i}" for i in range(len(data.measures))]
expected_cols = ["index"] + measure_cols + ["objective"]
expected_dtypes = [np.int32, *[dtype for _ in measure_cols], dtype]
if include_solutions:
solution_cols = [f"solution_{i}" for i in range(len(data.solution))]
expected_cols += solution_cols
expected_dtypes += [dtype for _ in solution_cols]
if include_metadata:
expected_cols.append("metadata")
expected_dtypes.append(object)
solution_dim = len(data.solution)
measure_dim = len(data.measures)
expected_cols = ([f"solution_{i}" for i in range(solution_dim)] +
["objective"] +
[f"measures_{i}" for i in range(measure_dim)] +
["metadata", "threshold", "index"])
expected_dtypes = ([dtype for _ in range(solution_dim)] + [dtype] +
[dtype for _ in range(measure_dim)] +
[object, dtype, np.int32])

# Retrieve the dataframe.
if with_elite:
df = data.archive_with_elite.as_pandas(
include_solutions=include_solutions,
include_metadata=include_metadata)
df = data.archive_with_elite.data(return_type="pandas")
else:
df = data.archive.as_pandas(include_solutions=include_solutions,
include_metadata=include_metadata)
df = data.archive.data(return_type="pandas")

# Check columns and data types.
assert (df.columns == expected_cols).all()
Expand All @@ -441,9 +432,8 @@ def test_as_pandas(name, with_elite, include_solutions, include_metadata,
assert df.loc[0, "index"] == data.archive.grid_to_int_index(
[data.grid_indices])[0]

expected_data = [*data.measures, data.objective]
if include_solutions:
expected_data += list(data.solution)
if include_metadata:
expected_data.append(data.metadata)
assert (df.loc[0, "measures_0":] == expected_data).all()
expected_data = [
*data.solution, data.objective, *data.measures, data.metadata,
data.objective
]
assert (df.loc[0, :"threshold"] == expected_data).all()

0 comments on commit 4f782db

Please sign in to comment.