Continuous CQL loss logging and aligning with discrete logging #317
Conversation
Codecov Report
@@            Coverage Diff             @@
##    refactor_loss     #317      +/-  ##
=========================================
+ Coverage       92.72%   92.91%   +0.18%
=========================================
  Files             108      108
  Lines            7353     7109     -244
=========================================
- Hits             6818     6605     -213
+ Misses            535      504      -31
Thank you for your PR! As I commented below, I'll take the first pass at refactoring the method signatures. Then, I'll let you know when it's done.
@@ -14,7 +14,7 @@
 docs/d3rlpy*.rst
 docs/modules.rst
 docs/references/generated
 coverage.xml
-.coverage
+.coverage*
Is this change necessary?
When running tests, I seem to get .coverage... files with references to my local system. Maybe it's the way I'm running the test?
return loss + conservative_loss, conservative_loss

@train_api
def update_critic(self, batch: TorchMiniBatch) -> np.array:
This is altering the signature of `update_critic`, which is `def update_critic(self, batch: TorchMiniBatch) -> float:`. Thus, we actually need to refactor these methods first. I'll take the first pass to return `Dict[str, float]`. Then, you can make changes on top of it. Sorry for the inconvenience. I'll let you know when it's done.
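The proposed refactor can be sketched as follows. This is a simplified, hypothetical illustration of the signature change being discussed (the names and placeholder values are not d3rlpy's actual implementation): instead of returning a single float, `update_critic` returns every named loss component so each one can be logged.

```python
from typing import Dict

def update_critic(batch) -> Dict[str, float]:
    # Hypothetical sketch: placeholder values stand in for the computed
    # TD loss and the CQL conservative penalty.
    loss = 1.5
    conservative_loss = 0.25
    # Return all named components so the logger can record each metric,
    # rather than collapsing everything into one float.
    return {
        "critic_loss": loss + conservative_loss,
        "conservative_loss": conservative_loss,
    }

print(update_critic(None))  # {'critic_loss': 1.75, 'conservative_loss': 0.25}
```

With this shape, the training loop can iterate over the returned dict and log every component uniformly across algorithms.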
The interface has been updated in this commit: 67723be . Please check.
@takuseno - I'm sorry for the delay! All looks good to me except my one comment on the AWAC d3rlpy/algos/qlearning/awac.py file. Cheers
@takuseno is there anything else you need from me at all for the PR? :)
@joshuaspear Thanks for the contribution. I've changed the target branch to `refactor_loss` because I've made some new changes to the master branch. I'm seeing conflicts between your PR and the `refactor_loss` branch. Could you resolve them? Here is an example instruction for merging the `refactor_loss` branch into your branch.
$ git fetch upstream
$ git checkout master
$ git merge upstream/refactor_loss
@takuseno no probs - will do
@takuseno have merged the branches - I included a couple more data classes for the loss outputs FYI
Fixed a small typo. Many thanks again!
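The loss-output dataclasses mentioned above could look something like this. This is a minimal sketch under assumptions: the class and field names are illustrative, not the actual identifiers added in the PR.

```python
import dataclasses
from typing import Dict

@dataclasses.dataclass(frozen=True)
class CriticLossRecord:
    # Hypothetical container for critic loss components; frozen so a
    # training step cannot mutate the recorded values after the fact.
    critic_loss: float
    conservative_loss: float

    def to_dict(self) -> Dict[str, float]:
        # Flatten into the Dict[str, float] shape the logger expects.
        return dataclasses.asdict(self)

record = CriticLossRecord(critic_loss=1.75, conservative_loss=0.25)
print(record.to_dict())  # {'critic_loss': 1.75, 'conservative_loss': 0.25}
```

A dataclass keeps the per-algorithm loss components named and type-checked, while `to_dict` preserves the uniform logging interface.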
…impl. Also: 1. Updated the conservative loss of discrete CQL to be captured including the alpha multiplication, to align with continuous CQL. 2. Updated the critic loss of DDPG and continuous CQL to use dataclasses, aligning with DQN and discrete CQL.
@joshuaspear Thank you for continuing this, but I'm seeing a weirdly large number of changes in your diff now. It could be easier to close this PR and make a new one based on the latest master to resolve this.
@takuseno makes sense :) Will have a crack at it next week
Implementing logging for the conservative loss component of continuous CQL. Also altered the logging of the loss in the discrete model to include the value of $\alpha$.
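The alignment described above can be sketched as follows. This is a hedged illustration, not the PR's actual code: the function name and metric keys are hypothetical, and the point is simply that the logged conservative term is the $\alpha$-scaled value, so discrete and continuous CQL report the same quantity.

```python
from typing import Dict, Tuple

def cql_critic_loss(
    td_loss: float, conservative_term: float, alpha: float
) -> Tuple[float, Dict[str, float]]:
    # Scale the conservative penalty by alpha BEFORE logging, so the
    # recorded metric matches the value actually added to the total loss.
    conservative_loss = alpha * conservative_term
    total = td_loss + conservative_loss
    metrics = {"loss": total, "conservative_loss": conservative_loss}
    return total, metrics

total, metrics = cql_critic_loss(td_loss=1.0, conservative_term=0.5, alpha=2.0)
print(total, metrics)  # 2.0 {'loss': 2.0, 'conservative_loss': 1.0}
```

Logging the unscaled term instead would make the discrete and continuous variants report incomparable numbers whenever $\alpha \neq 1$.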