From c6c581538f84091238192d5db555cd661efcee86 Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Wed, 8 Jan 2025 22:05:57 +0000 Subject: [PATCH] Revert `OrderedDict` key ordering in `Dict` space (#1291) --- gymnasium/spaces/dict.py | 6 ++++-- tests/spaces/test_dict.py | 16 +++++++++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/gymnasium/spaces/dict.py b/gymnasium/spaces/dict.py index 49ff4c907..709d8be15 100644 --- a/gymnasium/spaces/dict.py +++ b/gymnasium/spaces/dict.py @@ -4,6 +4,7 @@ import collections.abc import typing +from collections import OrderedDict from typing import Any, KeysView, Sequence import numpy as np @@ -66,8 +67,9 @@ def __init__( seed: Optionally, you can use this argument to seed the RNGs of the spaces that make up the :class:`Dict` space. **spaces_kwargs: If ``spaces`` is ``None``, you need to pass the constituent spaces as keyword arguments, as described above. """ - # Convert the spaces into an OrderedDict - if isinstance(spaces, collections.abc.Mapping): + if isinstance(spaces, OrderedDict): + spaces = dict(spaces.items()) + elif isinstance(spaces, collections.abc.Mapping): # for legacy reasons, we need to preserve the sorted dictionary items. # as this could matter for projects flatten the dictionary. try: diff --git a/tests/spaces/test_dict.py b/tests/spaces/test_dict.py index 54f617543..cc8483c16 100644 --- a/tests/spaces/test_dict.py +++ b/tests/spaces/test_dict.py @@ -37,8 +37,22 @@ def test_dict_init(): assert a == b == c == d assert len(caught_warnings) == 0 + # test sorting with warnings.catch_warnings(record=True) as caught_warnings: - Dict({1: Discrete(2), "a": Discrete(3)}) + # Sorting is applied to the keys + a = Dict({"b": Box(low=0.0, high=1.0), "a": Discrete(2)}) + assert a.keys() == {"a", "b"} + + # Sorting is not applied to the keys + b = Dict(OrderedDict(b=Box(low=0.0, high=1.0), a=Discrete(2))) + c = Dict((("b", Box(low=0.0, high=1.0)), ("a", Discrete(2)))) + d = Dict(b=Box(low=0.0, high=1.0), a=Discrete(2)) + assert b.keys() == c.keys() == d.keys() == {"b", "a"} + assert len(caught_warnings) == 0 + + # test sorting with different classes + with warnings.catch_warnings(record=True) as caught_warnings: + assert Dict({1: Discrete(2), "a": Discrete(3)}).keys() == {1, "a"} assert len(caught_warnings) == 0