From ea60d3962aa6ee89e7caa73bf0da28d2dca70e2d Mon Sep 17 00:00:00 2001
From: Jungsub Lim
Date: Mon, 25 Mar 2019 00:55:12 +0900
Subject: [PATCH] Update 4.prioritized dqn.ipynb

removed the volatile parameter and added the @torch.no_grad() decorator
zip(*samples) -> list(zip(*samples))
.data[0] -> .item()  # 2 cases
---
 4.prioritized dqn.ipynb | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/4.prioritized dqn.ipynb b/4.prioritized dqn.ipynb
index 04dcfc7..17d3c15 100644
--- a/4.prioritized dqn.ipynb
+++ b/4.prioritized dqn.ipynb
@@ -106,7 +106,7 @@
     "        weights /= weights.max()\n",
     "        weights = np.array(weights, dtype=np.float32)\n",
     "        \n",
-    "        batch = zip(*samples)\n",
+    "        batch = list(zip(*samples))\n",
     "        states = np.concatenate(batch[0])\n",
     "        actions = batch[1]\n",
     "        rewards = batch[2]\n",
@@ -264,7 +264,7 @@
     "        if random.random() > epsilon:\n",
     "            state = Variable(torch.FloatTensor(state).unsqueeze(0), volatile=True)\n",
     "            q_value = self.forward(state)\n",
-    "            action = q_value.max(1)[1].data[0]\n",
+    "            action = q_value.max(1)[1].item()\n",
     "        else:\n",
     "            action = random.randrange(env.action_space.n)\n",
     "        return action"
@@ -425,7 +425,7 @@
     "    if len(replay_buffer) > batch_size:\n",
     "        beta = beta_by_frame(frame_idx)\n",
     "        loss = compute_td_loss(batch_size, beta)\n",
-    "        losses.append(loss.data[0])\n",
+    "        losses.append(loss.item())\n",
     "        \n",
     "    if frame_idx % 200 == 0:\n",
     "        plot(frame_idx, all_rewards, losses)\n",
@@ -506,9 +506,10 @@
     "    def feature_size(self):\n",
     "        return self.features(autograd.Variable(torch.zeros(1, *self.input_shape))).view(1, -1).size(1)\n",
     "    \n",
+    "    @torch.no_grad()\n",
     "    def act(self, state, epsilon):\n",
     "        if random.random() > epsilon:\n",
-    "            state = Variable(torch.FloatTensor(np.float32(state)).unsqueeze(0), volatile=True)\n",
+    "            state = Variable(torch.FloatTensor(np.float32(state)).unsqueeze(0))\n",
     "            q_value = self.forward(state)\n",
     "            action = q_value.max(1)[1].data[0]\n",
     "        else:\n",
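
For reference, a minimal, self-contained sketch of the post-0.4 PyTorch idiom this patch adopts: torch.no_grad() in place of volatile=True, plain tensors in place of Variable, and .item() in place of .data[0]. The tiny 4-state/2-action network below is a placeholder, not the notebook's DQN, and act() is written as a free function rather than a method:

import random

import numpy as np
import torch
import torch.nn as nn

# Placeholder network: 4-dim state, 2 actions (stands in for the notebook's model).
q_net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))

@torch.no_grad()  # replaces volatile=True: no autograd graph is built inside act()
def act(state, epsilon, num_actions=2):
    if random.random() > epsilon:
        # A plain tensor replaces the Variable wrapper in PyTorch >= 0.4.
        state = torch.FloatTensor(np.float32(state)).unsqueeze(0)
        q_value = q_net(state)
        return q_value.max(1)[1].item()  # .item() replaces .data[0]
    return random.randrange(num_actions)

print(act(np.zeros(4, dtype=np.float32), epsilon=0.1))  # prints 0 or 1

Separately, the list(zip(*samples)) change is needed because zip() in Python 3 returns a one-shot iterator, so indexing batch[0] fails unless the result is first materialized with list().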