From 3af5eeb99daad519c58dbe2c673bd6d07f72963e Mon Sep 17 00:00:00 2001 From: "K.R. Zentner" Date: Sat, 16 Apr 2022 12:57:51 -0700 Subject: [PATCH] Support GPUs in RaySampler --- src/garage/sampler/ray_sampler.py | 35 ++++++++++++++----- src/garage/sampler/sampler.py | 18 +--------- .../sampler/test_ray_batched_sampler.py | 35 +++++++++++++++++++ 3 files changed, 62 insertions(+), 26 deletions(-) diff --git a/src/garage/sampler/ray_sampler.py b/src/garage/sampler/ray_sampler.py index 33ae011382..d5540e1853 100644 --- a/src/garage/sampler/ray_sampler.py +++ b/src/garage/sampler/ray_sampler.py @@ -37,11 +37,17 @@ class RaySampler(Sampler): The maximum length episodes which will be sampled. is_tf_worker (bool): Whether it is workers for TFTrainer. seed(int): The seed to use to initialize random number generators. - n_workers(int): The number of workers to use. + n_workers(int or None): The number of workers to use. Defaults to + number of physical cpus, if worker_factory is also None. worker_class(type): Class of the workers. Instances should implement the Worker interface. worker_args (dict or None): Additional arguments that should be passed to the worker. + n_gpus (int or float): Number of GPUs to to use in total for sampling. + If `n_workers` is not a power of two, this may need to be set + slightly below the true value, since `n_workers / n_gpus` gpus are + allocated to each worker. Defaults to zero, because otherwise + nothing would run if no gpus were available. """ @@ -54,18 +60,22 @@ def __init__( max_episode_length=None, is_tf_worker=False, seed=get_seed(), - n_workers=psutil.cpu_count(logical=False), + n_workers=None, worker_class=DefaultWorker, - worker_args=None): - # pylint: disable=super-init-not-called + worker_args=None, + n_gpus=0): if not ray.is_initialized(): ray.init(log_to_driver=False, ignore_reinit_error=True) if worker_factory is None and max_episode_length is None: raise TypeError('Must construct a sampler from WorkerFactory or' 'parameters (at least max_episode_length)') - if isinstance(worker_factory, WorkerFactory): + if worker_factory is not None: + if n_workers is None: + n_workers = worker_factory.n_workers self._worker_factory = worker_factory else: + if n_workers is None: + n_workers = psutil.cpu_count(logical=False) self._worker_factory = WorkerFactory( max_episode_length=max_episode_length, is_tf_worker=is_tf_worker, @@ -73,7 +83,9 @@ def __init__( n_workers=n_workers, worker_class=worker_class, worker_args=worker_args) - self._sampler_worker = ray.remote(SamplerWorker) + remote_wrapper = ray.remote(num_gpus=n_gpus / n_workers) + self._n_gpus = n_gpus + self._sampler_worker = remote_wrapper(SamplerWorker) self._agents = agents self._envs = self._worker_factory.prepare_worker_messages(envs) self._all_workers = defaultdict(None) @@ -103,7 +115,10 @@ def from_worker_factory(cls, worker_factory, agents, envs): Sampler: An instance of `cls`. """ - return cls(agents, envs, worker_factory=worker_factory) + return cls(agents, + envs, + worker_factory=worker_factory, + n_workers=worker_factory.n_workers) def start_worker(self): """Initialize a new ray worker.""" @@ -308,7 +323,8 @@ def __getstate__(self): """ return dict(factory=self._worker_factory, agents=self._agents, - envs=self._envs) + envs=self._envs, + n_gpus=self._n_gpus) def __setstate__(self, state): """Unpickle the state. @@ -319,7 +335,8 @@ def __setstate__(self, state): """ self.__init__(state['agents'], state['envs'], - worker_factory=state['factory']) + worker_factory=state['factory'], + n_gpus=state['n_gpus']) class SamplerWorker: diff --git a/src/garage/sampler/sampler.py b/src/garage/sampler/sampler.py index 1c888edd26..7ecf052edb 100644 --- a/src/garage/sampler/sampler.py +++ b/src/garage/sampler/sampler.py @@ -13,25 +13,9 @@ class Sampler(abc.ABC): `Sampler` needs. Specifically, it specifies how to construct `Worker`s, which know how to collect episodes and update both agents and environments. - Currently, `__init__` is also part of the interface, but calling it is - deprecated. `start_worker` is also deprecated, and does not need to be - implemented. + `start_worker` is deprecated, and does not need to be implemented. """ - def __init__(self, algo, env): - """Construct a Sampler from an Algorithm. - - Args: - algo (RLAlgorithm): The RL Algorithm controlling this - sampler. - env (Environment): The environment being sampled from. - - Calling this method is deprecated. - - """ - self.algo = algo - self.env = env - @classmethod def from_worker_factory(cls, worker_factory, agents, envs): """Construct this sampler. diff --git a/tests/garage/sampler/test_ray_batched_sampler.py b/tests/garage/sampler/test_ray_batched_sampler.py index 80ff2be263..a0ca485155 100644 --- a/tests/garage/sampler/test_ray_batched_sampler.py +++ b/tests/garage/sampler/test_ray_batched_sampler.py @@ -1,4 +1,5 @@ """Tests for ray_batched_sampler.""" +import pickle from unittest.mock import Mock import numpy as np @@ -138,6 +139,40 @@ def test_init_with_env_updates(ray_local_session_fixture): assert sum(episodes.lengths) >= 160 +def test_pickle(ray_local_session_fixture): + del ray_local_session_fixture + assert ray.is_initialized() + max_episode_length = 16 + env = PointEnv() + policy = FixedPolicy(env.spec, + scripted_actions=[ + env.action_space.sample() + for _ in range(max_episode_length) + ]) + tasks = SetTaskSampler(PointEnv) + n_workers = 4 + workers = WorkerFactory(seed=100, + max_episode_length=max_episode_length, + n_workers=n_workers) + sampler = RaySampler.from_worker_factory(workers, policy, env) + sampler_pickled = pickle.dumps(sampler) + sampler.shutdown_worker() + sampler2 = pickle.loads(sampler_pickled) + episodes = sampler2.obtain_samples(0, + 500, + np.asarray(policy.get_param_values()), + env_update=tasks.sample(n_workers)) + mean_rewards = [] + goals = [] + for eps in episodes.split(): + mean_rewards.append(eps.rewards.mean()) + goals.append(eps.env_infos['task'][0]['goal']) + assert np.var(mean_rewards) > 0 + assert np.var(goals) > 0 + sampler2.shutdown_worker() + env.close() + + def test_init_without_worker_factory(ray_local_session_fixture): del ray_local_session_fixture assert ray.is_initialized()