
Continuous CQL loss logging and aligning with discrete logging #317

Closed
wants to merge 21 commits

Conversation

joshuaspear (Contributor):

This PR implements logging for the conservative loss term of continuous CQL. It also alters the loss logging of the discrete model to include the value of $\alpha$.
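To make the intent concrete, here is a minimal sketch of the logging change being proposed. All names here are hypothetical, not d3rlpy's actual API: the point is that the conservative penalty is logged already multiplied by alpha, so the continuous and discrete variants report the same quantity.

```python
# Illustrative sketch only -- function and argument names are made up.
# The conservative term is returned already scaled by alpha, so the
# logged "conservative_loss" matches between continuous and discrete CQL.
def cql_critic_loss_metrics(td_loss: float, logsumexp_q: float,
                            data_q: float, alpha: float) -> dict:
    conservative_loss = alpha * (logsumexp_q - data_q)
    return {
        "critic_loss": td_loss + conservative_loss,
        "conservative_loss": conservative_loss,  # logged including alpha
        "alpha": alpha,
    }
```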

codecov bot commented Aug 10, 2023:

Codecov Report

Merging #317 (80b9579) into refactor_loss (36222ae) will increase coverage by 0.18%.
Report is 12 commits behind head on refactor_loss.
The diff coverage is 97.73%.

@@                Coverage Diff                @@
##           refactor_loss     #317      +/-   ##
=================================================
+ Coverage          92.72%   92.91%   +0.18%     
=================================================
  Files                108      108              
  Lines               7353     7109     -244     
=================================================
- Hits                6818     6605     -213     
+ Misses               535      504      -31     
Files Changed Coverage Δ
d3rlpy/metrics/evaluators.py 97.15% <ø> (ø)
d3rlpy/models/encoders.py 97.01% <ø> (+1.61%) ⬆️
d3rlpy/models/torch/q_functions/__init__.py 100.00% <ø> (ø)
d3rlpy/algos/transformer/base.py 93.15% <77.77%> (-0.52%) ⬇️
d3rlpy/algos/qlearning/base.py 88.21% <81.25%> (-0.15%) ⬇️
d3rlpy/algos/qlearning/cql.py 97.70% <88.23%> (-2.30%) ⬇️
d3rlpy/algos/qlearning/sac.py 97.67% <88.23%> (-2.33%) ⬇️
d3rlpy/models/torch/q_functions/base.py 80.00% <90.00%> (+0.51%) ⬆️
d3rlpy/algos/qlearning/torch/plas_impl.py 92.94% <91.07%> (+0.40%) ⬆️
d3rlpy/algos/qlearning/torch/bc_impl.py 92.64% <91.17%> (+9.04%) ⬆️
... and 43 more


takuseno (Owner) left a comment:

Thank you for your PR! As I commented below, I'll take a first pass at refactoring the method signatures. Then I'll let you know when it's done.

@@ -14,7 +14,7 @@ docs/d3rlpy*.rst
docs/modules.rst
docs/references/generated
coverage.xml
.coverage
.coverage*
takuseno (Owner):

Is this change necessary?

Contributor Author:

When running tests, I seem to get .coverage... files with references to my local system. Maybe it's the way I'm running the test?

return loss + conservative_loss, conservative_loss

@train_api
def update_critic(self, batch: TorchMiniBatch) -> np.array:
takuseno (Owner):

This alters the signature of update_critic, which is currently def update_critic(self, batch: TorchMiniBatch) -> float:. We actually need to refactor these methods first. I'll take a first pass to make them return Dict[str, float]; then you can make your changes on top of it. Sorry for the inconvenience. I'll let you know when it's done.
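A rough sketch of the shape of the refactor described here, under the assumption that each update method returns a dict of named loss values (class and values below are illustrative stand-ins, not d3rlpy's real implementation):

```python
from typing import Dict


class TorchMiniBatch:
    """Stand-in for d3rlpy's TorchMiniBatch; fields omitted for brevity."""


class CQLImplSketch:
    """Hedged sketch: update_critic returns Dict[str, float] instead of
    a bare float, so algorithm-specific terms such as the conservative
    loss can be logged without changing the signature again."""

    def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]:
        td_loss = 0.8            # stand-in for the computed TD loss
        conservative_loss = 0.2  # stand-in for the alpha-scaled penalty
        return {
            "critic_loss": td_loss + conservative_loss,
            "conservative_loss": conservative_loss,
        }
```

With this shape, callers that only want the scalar loss read the "critic_loss" key, while the logger can record every entry.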

takuseno (Owner):

The interface has been updated in this commit: 67723be . Please check.

Contributor Author:

@takuseno - I'm sorry for the delay! All looks good to me except for my one comment on the AWAC file (d3rlpy/algos/qlearning/awac.py). Cheers

Contributor Author:

@takuseno is there anything else you need from me at all for the PR? :)

takuseno (Owner):

@joshuaspear Thanks for the contribution. I've changed the target branch to refactor_loss because I've made some new changes to the master branch. I'm seeing conflicts between your PR and the refactor_loss branch. Could you resolve them? Here are example instructions to merge the refactor_loss branch into your branch.

$ git fetch upstream
$ git checkout master
$ git merge upstream/refactor_loss

Contributor Author:

@takuseno no probs - will do

Contributor Author:

@takuseno have merged the branches - FYI, I included a couple more dataclasses for the loss outputs.
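For context, a hypothetical sketch of the kind of loss-output dataclasses mentioned here (illustrative names, not d3rlpy's actual classes): a base class per algorithm family, extended with algorithm-specific terms and flattened into a metrics dict for the logger.

```python
import dataclasses
from dataclasses import dataclass


@dataclass(frozen=True)
class DDPGCriticLoss:
    """Base critic-loss record shared by DDPG-style algorithms."""
    critic_loss: float


@dataclass(frozen=True)
class CQLCriticLoss(DDPGCriticLoss):
    """CQL adds the alpha-scaled conservative penalty as its own field."""
    conservative_loss: float


def to_metrics(loss: DDPGCriticLoss) -> dict:
    """Flatten a loss dataclass into {field_name: value} for logging."""
    return dataclasses.asdict(loss)
```

Using frozen dataclasses keeps the loss record immutable after the update step, and dataclasses.asdict gives a uniform path from any loss class to the logger.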

@takuseno takuseno changed the base branch from master to refactor_loss August 26, 2023 01:55
…impl. Also: 1. Updated the conservative loss of discrete CQL to be captured including the alpha multiplication, aligning with continuous CQL. 2. Updated the critic loss of DDPG and continuous CQL to use dataclasses, aligning with DQN and discrete CQL.

takuseno commented Sep 2, 2023

@joshuaspear Thank you for continuing this, but I'm now seeing a weirdly large number of changes in your diff. It may be easier to close this PR and open a new one based on the latest master to resolve this.

joshuaspear (Contributor Author) commented:

@takuseno makes sense :) Will have a crack at it next week

@takuseno takuseno deleted the branch takuseno:refactor_loss December 2, 2023 08:33
@takuseno takuseno closed this Dec 2, 2023