[FEATURE] Discrete IQL #404

Open · wants to merge 3 commits into master

Conversation

@Mamba413 (Author)

I have implemented an IQL algorithm that supports discrete actions. I have tested it on my local machine and confirmed that it works.

Below is my test code:

from d3rlpy.algos import DiscreteIQLConfig, DiscreteCQLConfig
from d3rlpy.datasets import get_cartpole
from d3rlpy.metrics import EnvironmentEvaluator

import os

os.chdir(os.path.dirname(os.path.abspath(__file__)))

def main():
    dataset, env = get_cartpole()

    iql = DiscreteIQLConfig().create(device="cpu")
    iql.build_with_dataset(dataset)
    iql.fit(
        dataset,
        n_steps=30000,
        evaluators={
            "environment": EnvironmentEvaluator(env),
        },
    )


if __name__ == "__main__":
    main()
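
As a small follow-up (a usage sketch, not part of the PR; the filename is arbitrary), the trained policy can be saved and reloaded like any other d3rlpy algorithm, mirroring what the LunarLander script further below does with DDQN:

import d3rlpy

# after fit() returns inside main(), the learned policy can be persisted...
iql.save("discrete_iql_cartpole.d3")
# ...and reloaded later, e.g. for evaluation
iql_reloaded = d3rlpy.load_learnable("discrete_iql_cartpole.d3")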

@Mamba413 (Author)

I also tested it on the LunarLander environment and found that it surpasses DiscreteCQL when the number of training iterations is small.

from scope_rl.dataset import SyntheticDataset
from scope_rl.policy import EpsilonGreedyHead
from d3rlpy.algos import DoubleDQNConfig
from d3rlpy.dataset import create_fifo_replay_buffer
from d3rlpy.algos import ConstantEpsilonGreedy
import gym
import d3rlpy

import os
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# random state
random_state = 12345
device = "cpu"

# (0) Setup environment
env = gym.make("LunarLander-v2")

eval_env = gym.make("LunarLander-v2")

# (1) Learn a baseline policy in an online environment (using d3rlpy)
# initialize the algorithm
ddqn = DoubleDQNConfig().create(device=device)
# train an online policy
ddqn.fit_online(
    env,
    buffer=create_fifo_replay_buffer(limit=50000, env=env),
    explorer=ConstantEpsilonGreedy(epsilon=0.3),
    n_steps=1000000,
    update_start_step=10000,
    eval_env=eval_env, 
    save_interval=100000,
)
ddqn.save('ddqn_LunarLander.d3')

ddqn = d3rlpy.load_learnable('ddqn_LunarLander.d3')
behavior_policy = EpsilonGreedyHead(
    ddqn,
    n_actions=env.action_space.n,
    epsilon=0.3,
    name="ddqn_epsilon_0.3",
    random_state=random_state,
)
# (2) Collect logged data with the behavior policy (using SCOPE-RL)
# initialize the dataset class
dataset = SyntheticDataset(
    env=env,
    max_episode_steps=600,
)
# the behavior policy collects some logged data
train_logged_dataset = dataset.obtain_episodes(
  behavior_policies=behavior_policy,
  n_trajectories=1000,
  random_state=random_state,
)

from d3rlpy.dataset import MDPDataset
from d3rlpy.algos import DiscreteIQLConfig, DiscreteCQLConfig
from d3rlpy.metrics import EnvironmentEvaluator

# (3) Learning a new policy from offline logged data (using d3rlpy)
# convert the logged dataset into d3rlpy's dataset format
offlinerl_dataset = MDPDataset(
    observations=train_logged_dataset["state"],
    actions=train_logged_dataset["action"],
    rewards=train_logged_dataset["reward"],
    terminals=train_logged_dataset["done"],
)
# initialize the algorithm
cql = DiscreteCQLConfig().create(device=device)
# train an offline policy
cql.fit(
    offlinerl_dataset,
    n_steps=100000,
    save_interval=10000,
    evaluators={
        "environment": EnvironmentEvaluator(env),
    },
)

# initialize discrete IQL for comparison
iql = DiscreteIQLConfig().create(device=device)
# train an offline policy
iql.fit(
    offlinerl_dataset,
    n_steps=100000,
    save_interval=10000,
    evaluators={
        "environment": EnvironmentEvaluator(env),
    },
)
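
To compare the two runs quantitatively, one can read the values that the "environment" evaluator writes during fit(). Below is a sketch, assuming d3rlpy's default file logger layout (an environment.csv of epoch,step,value rows without a header under a timestamped d3rlpy_logs/<experiment>/ directory; both the layout and the directory patterns are assumptions):

import glob

import pandas as pd


def load_env_returns(pattern: str) -> pd.DataFrame:
    # pick the latest run matching the pattern,
    # e.g. "d3rlpy_logs/DiscreteIQL_*/environment.csv"
    path = sorted(glob.glob(pattern))[-1]
    # assumed CSV layout: epoch,step,value rows without a header
    return pd.read_csv(path, names=["epoch", "step", "return"])


cql_curve = load_env_returns("d3rlpy_logs/DiscreteCQL_*/environment.csv")
iql_curve = load_env_returns("d3rlpy_logs/DiscreteIQL_*/environment.csv")
print(cql_curve.merge(iql_curve, on="step", suffixes=("_cql", "_iql")).tail())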

@takuseno (Owner) left a comment:

@Mamba413 Hi, thanks for your contribution! I've left some comments on your changes. Apart from that, I'd like you to add a unit test to this file and make sure that the test passes:
https://github.com/takuseno/d3rlpy/blob/master/tests/algos/qlearning/test_iql.py

Also, could you add a docstring to DiscreteIQLConfig just like here?

r"""Implicit Q-Learning algorithm.
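
For reference, here is a rough sketch of what such a docstring could look like (placeholder wording, not the final text), modeled on the continuous IQLConfig docstring:

class DiscreteIQLConfig:  # illustrative stub only, not the actual class definition
    r"""Implicit Q-Learning algorithm for discrete action spaces.

    IQL avoids querying out-of-distribution actions by fitting a state-value
    function with expectile regression and extracting the policy via
    advantage-weighted updates. This config applies the same objectives to
    discrete-action Q-functions.

    References:
        * Kostrikov et al., Offline Reinforcement Learning with Implicit
          Q-Learning. https://arxiv.org/abs/2110.06169
    """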

Finally, was the discrete version of IQL explained or used in any papers? If there isn't any evidence that this works better than DQN, I'm skeptical about the necessity of DiscreteIQL.

Comment on lines 166 to 167
_q_func_forwarder: ContinuousEnsembleQFunctionForwarder
_targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder
@takuseno (Owner):
These need to be DiscreteEnsembleQFunctionForwarder.

@Mamba413 (Author) commented on Jul 19, 2024:

If you mean DiscreteEnsembleQFunctionForwarder, I think this is resolved now.

Resolved review threads: d3rlpy/algos/qlearning/torch/ddpg_impl.py (outdated), d3rlpy/algos/qlearning/torch/iql_impl.py
@Mamba413 (Author)

Hi @takuseno, let me first answer your last comment. As you can see from Table 10 in this paper: https://arxiv.org/pdf/2303.15810, Discrete IQL (D-IQL) surpasses Discrete-CQL (D-CQL) in 2 of the 3 tasks.

On the other hand, Discrete Sparse Q-Learning (D-SQL) has the best performance in Table 10. Given the similarity between IQL and SQL, I would also be glad to implement SQL in the d3rlpy package.

Finally, I will modify the code soon.

@Mamba413 (Author)

By the way, I believe the implementation of discrete IQL can be further improved. The current implementation uses a stochastic policy that has to be updated; however, this update can actually be avoided, as in Discrete CQL, to gain higher computational efficiency. I haven't implemented this faster version because it is more complicated and I don't yet fully understand the entire software design.
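
For illustration, here is a minimal sketch of the shortcut described above, in plain PyTorch rather than d3rlpy's internal API (the tensor shape is an assumption): in the discrete setting the greedy policy can be read directly off the Q-function, as Discrete CQL does, so no separate stochastic policy network needs to be trained.

import torch


def greedy_actions(q_values: torch.Tensor) -> torch.Tensor:
    # q_values: ensemble-reduced Q-values of shape (batch_size, n_actions).
    # Acting greedily makes the policy implicit in argmax over Q, so the
    # stochastic-policy update step can be skipped entirely.
    return q_values.argmax(dim=1)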

@takuseno (Owner) commented on Jul 20, 2024:

> https://arxiv.org/pdf/2303.15810, Discrete IQL (D-IQL) surpasses Discrete-CQL (D-CQL) in 2 of the 3 tasks.

Ah, I didn't know that! Thank you for sharing this. Now, I'm happy to include DiscreteIQL (it'd be even nicer if you could add SQL as well 😉 ). I'm looking forward to the fix you're working on. Btw, the format check in CI complains about your change. Could you also try this before you finalize your PR?

pip install -r dev.requirements.txt
./scripts/format
./scripts/lint

Thanks!

@takuseno (Owner)

> By the way, I believe the implementation of discrete IQL can be further improved. The current implementation uses a stochastic policy that has to be updated; however, this update can actually be avoided, as in Discrete CQL, to gain higher computational efficiency. I haven't implemented this faster version because it is more complicated and I don't yet fully understand the entire software design.

Please do not worry about this. If there is a way to optimize your code, I can do that on my side.

@Mamba413 (Author)

> https://arxiv.org/pdf/2303.15810, Discrete IQL (D-IQL) surpasses Discrete-CQL (D-CQL) in 2 of the 3 tasks.
>
> Ah, I didn't know that! Thank you for sharing this. Now, I'm happy to include DiscreteIQL (it'd be even nicer if you could add SQL as well 😉 ). I'm looking forward to the fix you're working on. Btw, the format check in CI complains about your change. Could you also try this before you finalize your PR?
>
> pip install -r dev.requirements.txt
> ./scripts/format
> ./scripts/lint
>
> Thanks!

I just updated the code following your previous comment. I still have an unsolved problem when I run:

./scripts/lint

I found that it returns many errors:

tests/preprocessing/test_base.py:16: error: Unused "type: ignore" comment  [unused-ignore]
tests/dataset/test_trajectory_slicer.py:58: error: Unused "type: ignore" comment  [unused-ignore]
tests/dataset/test_trajectory_slicer.py:59: error: Unused "type: ignore" comment  [unused-ignore]
tests/dataset/test_trajectory_slicer.py:145: error: Unused "type: ignore" comment  [unused-ignore]
tests/dataset/test_trajectory_slicer.py:146: error: Unused "type: ignore" comment  [unused-ignore]
tests/dataset/test_mini_batch.py:95: error: Unused "type: ignore" comment  [unused-ignore]
tests/dataset/test_mini_batch.py:96: error: Unused "type: ignore" comment  [unused-ignore]
tests/dataset/test_mini_batch.py:97: error: Unused "type: ignore" comment  [unused-ignore]
tests/dataset/test_mini_batch.py:98: error: Unused "type: ignore" comment  [unused-ignore]
d3rlpy/algos/qlearning/torch/ddpg_impl.py:246: error: "ActionOutput" has no attribute "probs"  [attr-defined]
tests/algos/qlearning/test_random_policy.py:50: error: Unused "type: ignore" comment  [unused-ignore]
tests/algos/qlearning/test_random_policy.py:55: error: Unused "type: ignore" comment  [unused-ignore]
tests/envs/test_wrappers.py:29: error: Unused "type: ignore" comment  [unused-ignore]
tests/envs/test_wrappers.py:33: error: Unused "type: ignore" comment  [unused-ignore]
tests/envs/test_wrappers.py:51: error: Unused "type: ignore" comment  [unused-ignore]
tests/envs/test_wrappers.py:55: error: Unused "type: ignore" comment  [unused-ignore]

I have already addressed some of them, but it is still not clear how to address this one:

d3rlpy/algos/qlearning/torch/ddpg_impl.py:246: error: "ActionOutput" has no attribute "probs"  [attr-defined]

as fixing it would require a lot of changes to my implementation, and I am not sure whether the code would still work after those changes (a sketch of one possible workaround follows this comment).

Besides, I feel the following error messages do not come from my modification, as I haven't modified the test_wrappers.py file:

tests/envs/test_wrappers.py:55: error: Unused "type: ignore" comment  [unused-ignore]
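
For what it's worth, one way to avoid attaching an ad-hoc .probs attribute to d3rlpy's ActionOutput is to carry the categorical probabilities in a small dedicated container used only by the discrete implementation. This is a sketch only; DiscretePolicyOutput and make_discrete_policy_output are hypothetical names, not part of d3rlpy.

import dataclasses

import torch


@dataclasses.dataclass(frozen=True)
class DiscretePolicyOutput:
    # hypothetical container for a discrete policy head's outputs
    logits: torch.Tensor  # unnormalized scores, shape (batch_size, n_actions)
    probs: torch.Tensor   # softmax(logits), shape (batch_size, n_actions)


def make_discrete_policy_output(logits: torch.Tensor) -> DiscretePolicyOutput:
    return DiscretePolicyOutput(logits=logits, probs=torch.softmax(logits, dim=-1))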

@takuseno (Owner) left a comment:

Thanks for the change! It seems that there is something wrong with mypy tests, but I'll follow up on that after we merge this PR. One thing I need you to do here is to remove q_func_factory from DiscreteIQLConfig and use MeanQFunctionFactory. This is because we can't really change Q-function types due to the state-value function in IQL.

@pytest.mark.parametrize("scalers", [None, "min_max"])
def test_discrete_iql(
    observation_shape: Shape,
    q_func_factory: QFunctionFactory,
@takuseno (Owner):

Can you remove q_func_factory here?

observation_shape,
action_size,
self._config.encoder_factory,
self._config.q_func_factory,
@takuseno (Owner):

Can you remove q_func_factory from the config? Instead, please use MeanQFunctionFactory, just like the continuous IQL does.
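
In other words (a sketch of the requested change, showing only the relevant line), the Q-function factory would be hard-coded where the config field is currently passed, and the q_func_factory field would be dropped from DiscreteIQLConfig:

from d3rlpy.models.q_functions import MeanQFunctionFactory

# pass this where the excerpt above currently passes self._config.q_func_factory
q_func_factory = MeanQFunctionFactory()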
