-
Notifications
You must be signed in to change notification settings - Fork 244
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Continous CQL loss logging and aligning with discrete logging #317
Changes from 5 commits
4a6edc9
5b7185c
7fd0a37
a52651e
a46d73f
7255159
294544d
b8263d4
3ec25d7
bae6777
817810d
4ba297f
2d730ef
9d4f928
0874990
2846d20
8e5aec8
fa90b6c
b1c1e62
5bd59a1
80b9579
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -75,7 +75,23 @@ def compute_critic_loss( | |
conservative_loss = self._compute_conservative_loss( | ||
batch.observations, batch.actions, batch.next_observations | ||
) | ||
return loss + conservative_loss | ||
return loss + conservative_loss, conservative_loss | ||
|
||
@train_api | ||
def update_critic(self, batch: TorchMiniBatch) -> np.array: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is altering a signature of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The interface has been updated in this commit: 67723be . Please check. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @takuseno - I'm sorry for the delay! All looks good to me except my one comment on the AWAC d3rlpy/algos/qlearning/awac.py file. Cheers There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @takuseno is there anything else you need from me at all for the PR? :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @joshuaspear Thanks for the contribution. I've changed the target branch to
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @takuseno no probs - will do There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @takuseno have merged the branches - I included a couple more data classes for the loss outputs FYI |
||
self._critic_optim.zero_grad() | ||
|
||
q_tpn = self.compute_target(batch) | ||
|
||
loss, cql_loss = self.compute_critic_loss(batch, q_tpn) | ||
|
||
loss.backward() | ||
self._critic_optim.step() | ||
|
||
critic_loss = float(loss.cpu().detach().numpy()) | ||
cql_loss = float(cql_loss.cpu().detach().numpy()) | ||
res = np.array([critic_loss, cql_loss]) | ||
return res | ||
|
||
@train_api | ||
def update_alpha(self, batch: TorchMiniBatch) -> Tuple[float, float]: | ||
|
@@ -221,7 +237,8 @@ def compute_loss( | |
conservative_loss = self._compute_conservative_loss( | ||
batch.observations, batch.actions.long() | ||
) | ||
return loss + self._alpha * conservative_loss, conservative_loss | ||
cql_loss = self._alpha * conservative_loss | ||
return loss + cql_loss, cql_loss | ||
|
||
def _compute_conservative_loss( | ||
self, obs_t: torch.Tensor, act_t: torch.Tensor | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this change necessary?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When running tests, I seem to get .coverage... files with references to my local system. Maybe it's the way I'm running the test?