Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(zms): add new league middlewares and other models and tools. #458

Open
wants to merge 288 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
288 commits
Select commit Hold shift + click to select a range
b404024
change a bit
hiha3456 May 28, 2022
ffcc50b
change a bit
hiha3456 May 28, 2022
a55fe17
change a bit
hiha3456 May 28, 2022
8084dae
change a bit
hiha3456 May 28, 2022
da46b22
change a bit
hiha3456 May 28, 2022
65b8762
change a bit
hiha3456 May 28, 2022
ad9632c
change position of policy_resetter, battle_rolloutor, battle_inferencer
hiha3456 May 28, 2022
192e6e4
polish stdim & infonce loss
lixl-st May 30, 2022
46ee7cf
polish EventEnum
lixl-st May 30, 2022
654eb9f
Merge branch 'main' of https://github.com/opendilab/DI-engine into de…
lixl-st May 30, 2022
48ddc09
polish EventEnum
lixl-st May 30, 2022
474524e
simplify the code
hiha3456 May 28, 2022
9d91063
for multiple policies
hiha3456 May 28, 2022
6ca4ae2
change a bit
hiha3456 May 28, 2022
2d85ccf
fix style
lixl-st May 30, 2022
157b9b3
revert to drop the streaming type code
hiha3456 May 30, 2022
a13b5e6
fix codecov
lixl-st May 30, 2022
9a25055
Merge branch 'main' of https://github.com/opendilab/DI-engine
lixl-st May 30, 2022
ecd0119
fix import
lixl-st May 30, 2022
4d54abf
add readme
lixl-st May 30, 2022
65525c1
drop useless codes
hiha3456 May 30, 2022
a1b0908
change rolloutor
hiha3456 May 30, 2022
68191a9
move on_league_job in call
hiha3456 May 30, 2022
59f2a6e
change __call__
hiha3456 May 31, 2022
2010965
solve conflicts
lixl-st May 31, 2022
1382b7f
solve conflicts
lixl-st May 31, 2022
b546f8a
Merge branch 'dev-distar-actor' of https://github.com/opendilab/DI-en…
lixl-st May 31, 2022
cfabb47
add league coordinator
lixl-st May 31, 2022
0ff864d
reformatting
hiha3456 May 31, 2022
4cf60f9
Merge branch 'main' into dev-distar-actor
hiha3456 May 31, 2022
3e71184
Merge branch 'dev-distar-actor' into dev-league-lxl
hiha3456 May 31, 2022
3f0dcb6
Merge pull request #345 from lixl-st/dev-league-lxl
hiha3456 May 31, 2022
6542a30
add new Enum event names inside actor and test
hiha3456 May 31, 2022
c7fac3f
add test coordinator & actor pipeline
hiha3456 May 31, 2022
49ccd6c
modify _on_learner_model method inside league_actor to make it thread…
hiha3456 Jun 1, 2022
6a2b76c
change test_league_pipeline.py to multiple process
hiha3456 Jun 1, 2022
aeac530
fix a problem of env initialization inside the creation of collector
hiha3456 Jun 1, 2022
0cd0e1d
polish league coordinator
lixl-st Jun 1, 2022
c0d5dd0
Merge branch 'dev-distar-actor' of https://github.com/opendilab/DI-en…
lixl-st Jun 1, 2022
fe26bb3
change "policies" to "current_policies" to make it clear
hiha3456 Jun 1, 2022
bd6395b
polish league coordinator
lixl-st Jun 1, 2022
f0601c6
Merge pull request #346 from lixl-st/dev-league-lxl
hiha3456 Jun 1, 2022
465ce9b
add streaming data collection
hiha3456 Jun 1, 2022
7ff8a0a
Merge branch 'dev-distar-actor' of https://github.com/opendilab/DI-en…
hiha3456 Jun 1, 2022
a6e9343
demo(nyz): add distar model
PaParaZz1 Jun 1, 2022
8093b72
feature(nyz): polish encoder and value related modules (ci skip)
PaParaZz1 Jun 1, 2022
83abc64
add locker and model_dict
hiha3456 Jun 1, 2022
604ed59
change name of vars and add BattleContext
hiha3456 Jun 1, 2022
2ad937d
debugging
hiha3456 Jun 1, 2022
bfd149d
update quick colab link
lixl-st Jun 2, 2022
b13d49e
Merge branch 'main' into main
lixl-st Jun 2, 2022
1071dd1
drop commit "add BattleContext, policy_getter, policy_updater, modify…
hiha3456 Jun 2, 2022
310c279
change the logic of update model
hiha3456 Jun 2, 2022
d8a385d
add actor._get_current_policies and collector._update_policies
hiha3456 Jun 2, 2022
0795098
change variable names
hiha3456 Jun 6, 2022
8d295c0
Merge branch 'dev-league-lxl' of github.com:lixl-st/DI-engine into de…
lixl-st Jun 6, 2022
c4397b2
Merge branch 'dev-distar' of github.com:opendilab/DI-engine into dev-…
lixl-st Jun 6, 2022
d0df49e
merge from dev-distar-learn
lixl-st Jun 6, 2022
ab7b556
add league policy
lixl-st Jun 6, 2022
1137a17
reformat
lixl-st Jun 6, 2022
400fa60
Merge pull request #349 from opendilab/main
hiha3456 Jun 6, 2022
e9d3b47
polish(nyz): polish and test distar head
PaParaZz1 Jun 6, 2022
fd71c46
add learner to test pipeline
lixl-st Jun 7, 2022
77dc67e
change vars names, get rid of cache_pool
hiha3456 Jun 8, 2022
f95113a
change the position of traj_buffers initialization
hiha3456 Jun 8, 2022
7d5c112
change a bit rolloutor
hiha3456 Jun 8, 2022
bc8c332
change the style to 1.0 middleware
hiha3456 Jun 8, 2022
1078673
change a bit
hiha3456 Jun 8, 2022
d3fcc51
change a bit
hiha3456 Jun 8, 2022
0dc1c5f
change variable name
hiha3456 Jun 8, 2022
5944584
add step collector and step actor
hiha3456 Jun 8, 2022
c171d59
feature&optim(zzh): add DDPPO & add model-based SAC with lambda-retur…
ZHZisZZ Jun 5, 2022
2859423
add league learner
lixl-st Jun 9, 2022
2aa05da
Merge branch 'dev-distar' of github.com:opendilab/DI-engine into dev-…
lixl-st Jun 9, 2022
2c46dbb
solve conflicts
lixl-st Jun 9, 2022
db3afc2
change the distar_env to fit BaseEnvManager, write test_distar_env_wi…
hiha3456 Jun 9, 2022
3dc0b12
arrage tests
hiha3456 Jun 10, 2022
9b4deaa
format context and league_actor
hiha3456 Jun 10, 2022
c65fecf
make old test runnable
hiha3456 Jun 10, 2022
cd18955
change the distar_env to fit BaseEnvManager, write test_distar_env_wi…
hiha3456 Jun 9, 2022
6a3d433
arrage tests
hiha3456 Jun 10, 2022
e3c95ff
make old test runnable
hiha3456 Jun 10, 2022
34e6374
Merge branch 'dev-distar' of github.com:opendilab/DI-engine into dev-…
lixl-st Jun 10, 2022
8027760
change random_action to a classmethod
hiha3456 Jun 13, 2022
cbe964c
change a bit DI-star env
hiha3456 Jun 13, 2022
fa07d12
add mock policy
lixl-st Jun 13, 2022
7ad26ef
reformat
lixl-st Jun 13, 2022
56ed8e1
fix a bug
hiha3456 Jun 13, 2022
b934260
modification to run DI-star in pipeline
hiha3456 Jun 13, 2022
400a19d
Merge pull request #350 from lixl-st/dev-league-lxl
hiha3456 Jun 13, 2022
e13a5f2
merge dev-distar branch and fix conflicts
hiha3456 Jun 13, 2022
3518a03
fix conflicts
hiha3456 Jun 13, 2022
9ae914a
change mocks
hiha3456 Jun 13, 2022
2d050d3
one_process test pass, multiple_process test failed because cannot pi…
hiha3456 Jun 13, 2022
82af279
change config
hiha3456 Jun 13, 2022
d7faf35
Merge pull request #361 from opendilab/dev-distar-collector
hiha3456 Jun 13, 2022
1e78af8
transform transitions so they could be sent, and write responding tes…
hiha3456 Jun 14, 2022
ece3a64
feature(nyz): add distar policy learn part
PaParaZz1 Jun 14, 2022
e9e8ebb
change format
hiha3456 Jun 14, 2022
f9f6ef6
Merge pull request #365 from opendilab/dev-distar-collector
hiha3456 Jun 14, 2022
22c8bcd
Merge branch 'dev-distar-learn' of github.com:opendilab/DI-engine int…
lixl-st Jun 15, 2022
d739635
Merge branch 'dev-distar' of github.com:opendilab/DI-engine into dev-…
lixl-st Jun 15, 2022
ffe8f7c
change print to show node_id; change format
hiha3456 Jun 15, 2022
b6c24b7
handle exception during reset SC2 env
hiha3456 Jun 15, 2022
78bed49
rolloutor handle error during step
hiha3456 Jun 15, 2022
2ad4850
add test for step exception
hiha3456 Jun 15, 2022
5685f3b
Merge pull request #372 from opendilab/dev-distar-collector
hiha3456 Jun 15, 2022
72dfa38
fix a bug of TransitionList, and simpify BattleContext
hiha3456 Jun 15, 2022
f6ec222
change a bit
hiha3456 Jun 15, 2022
b6129ed
change var agent_num
hiha3456 Jun 15, 2022
c16e9e7
make n_episode more clear
hiha3456 Jun 15, 2022
6fa1209
get rid of ctx.job
hiha3456 Jun 15, 2022
09b7145
polish(nyz): remove whole_cfg in distar model and fix bugs
PaParaZz1 Jun 15, 2022
06fee68
Merge pull request #375 from opendilab/dev-distar-collector
hiha3456 Jun 16, 2022
6d653a1
add timer
hiha3456 Jun 16, 2022
dc718dc
Merge pull request #376 from opendilab/dev-distar-collector
hiha3456 Jun 16, 2022
5d422d6
add log
hiha3456 Jun 16, 2022
c0d5a52
Merge branch 'dev-distar' of github.com:opendilab/DI-engine into dev-…
lixl-st Jun 16, 2022
ea3269d
Merge pull request #377 from opendilab/dev-distar-collector
hiha3456 Jun 16, 2022
1746905
add walltime data, change ActorData Structure
hiha3456 Jun 17, 2022
de5f8bb
Merge branch 'dev-distar' of github.com:opendilab/DI-engine into dev-…
lixl-st Jun 17, 2022
a30e23b
battle_transition_list, need to change list to deque
hiha3456 Jun 17, 2022
d7a2563
test(nyz): add distar policy learn unittest
PaParaZz1 Jun 19, 2022
81502a0
merge
lixl-st Jun 20, 2022
b4dcce0
merge
lixl-st Jun 20, 2022
8f93810
merge main into dev-distar
hiha3456 Jun 20, 2022
6d89d09
Merge branch 'dev-distar' of github.com:opendilab/DI-engine into dev-…
hiha3456 Jun 20, 2022
96debf7
Merge branch 'dev-distar' into dev-distar-collector
hiha3456 Jun 20, 2022
ca0987b
change a bit
hiha3456 Jun 20, 2022
006132f
change a bit
hiha3456 Jun 20, 2022
c664f60
add comment to BattleTransitionList
hiha3456 Jun 20, 2022
78bfd97
change a bit
hiha3456 Jun 20, 2022
b64a3ec
add BattleTransitionList into league_actors, but have some unexpected…
hiha3456 Jun 20, 2022
8d30256
add league learner exchanger
lixl-st Jun 21, 2022
0bd6e4a
deal with step error
hiha3456 Jun 21, 2022
338a2a2
change a bit env_supervisor so it could run DI-star env
hiha3456 Jun 21, 2022
65e3fb1
polish pipeline & add distar example
lixl-st Jun 21, 2022
a9f1fc0
Merge branch 'dev-distar' into dev-league-lxl
lixl-st Jun 21, 2022
805d219
make pipelines run in supervisor
hiha3456 Jun 21, 2022
917a6e7
Merge pull request #388 from lixl-st/dev-league-lxl
hiha3456 Jun 21, 2022
716c55b
fix conflicts
hiha3456 Jun 21, 2022
87ac3ce
reformat
hiha3456 Jun 21, 2022
53bfdf5
feature(nyz): add basic distar policy collect(ci skip)
PaParaZz1 Jun 21, 2022
d467140
change init to make test_pipeline runnable
hiha3456 Jun 21, 2022
7ea23f6
adjust league_learner but cannot run
hiha3456 Jun 21, 2022
20d4c2b
add notes in conference
hiha3456 Jun 22, 2022
984735f
Merge pull request #392 from opendilab/dev-distar-collector
hiha3456 Jun 22, 2022
2e7fb1a
merge
lixl-st Jun 22, 2022
e50fc12
Merge branch 'dev-distar' into dev-distar-learn
hiha3456 Jun 22, 2022
31cd26a
Merge pull request #393 from opendilab/dev-distar-learn
hiha3456 Jun 22, 2022
eafdf9c
change commit position
hiha3456 Jun 22, 2022
c989cc0
Merge branch 'dev-distar' into dev-distar-collector
hiha3456 Jun 22, 2022
a061243
add z infos
hiha3456 Jun 22, 2022
f44f54d
adjust codes to run the pipeline
hiha3456 Jun 22, 2022
763b627
Merge pull request #396 from opendilab/dev-distar-fix-bug
hiha3456 Jun 23, 2022
2074302
merge
lixl-st Jun 23, 2022
9c0d6f3
polish(nyz): polish parse_new_game and add transform_obs
PaParaZz1 Jun 23, 2022
37156d4
drop get config
hiha3456 Jun 23, 2022
c923c23
drop useless remain_episode and ready_env_ids
hiha3456 Jun 23, 2022
5db2b52
feature(zms): remove the episodes shorter than unroll_len
hiha3456 Jun 23, 2022
63ac2df
get game_info, map_name, map_size inside DIStarEnv.reset()
hiha3456 Jun 23, 2022
c1cc5af
final_eval_reward
hiha3456 Jun 23, 2022
6aefbe2
change a bit
hiha3456 Jun 23, 2022
b0a1051
add logging
lixl-st Jun 23, 2022
4ef8ea8
rm exp files
lixl-st Jun 23, 2022
7d3a152
add info["result"]
hiha3456 Jun 23, 2022
c7e0b4c
polish
lixl-st Jun 23, 2022
7137660
Merge branch 'dev-distar' into dev-league-lxl
hiha3456 Jun 23, 2022
61c3b86
Merge pull request #398 from lixl-st/dev-league-lxl
hiha3456 Jun 23, 2022
723cbcc
remove rep
hiha3456 Jun 23, 2022
0bc4b20
add comment
hiha3456 Jun 24, 2022
02b9d0b
add result info in job.info
hiha3456 Jun 24, 2022
b72cb78
comment Episode Actor and Episode Collector
hiha3456 Jun 24, 2022
e12a82e
move battle_inferencer_for_distar, battle_rolloutor_for_distar to fun…
hiha3456 Jun 24, 2022
cb97628
remove rep
hiha3456 Jun 24, 2022
19aae09
change to fix bug on k8s
hiha3456 Jun 24, 2022
e9aad7e
change a bit data_processor.py
hiha3456 Jun 27, 2022
98effdc
make old tests could run
hiha3456 Jun 27, 2022
2d71278
feature(zms): check for 60s if get new model or not
hiha3456 Jun 27, 2022
94d750d
Merge pull request #402 from opendilab/dev-distar-collector
hiha3456 Jun 27, 2022
61883a7
change commit to run in k8s
hiha3456 Jun 27, 2022
def4799
Merge branch 'dev-distar-collector' into dev-distar
hiha3456 Jun 27, 2022
05cc0f0
fix(zms): fix the bug that when job begin, there is a infinite loop
hiha3456 Jun 27, 2022
8f716f1
Merge branch 'dev-distar-collector' into dev-distar
hiha3456 Jun 27, 2022
d6ef348
update train iter
hiha3456 Jun 28, 2022
0cb987f
change logic of update train_iter
hiha3456 Jun 28, 2022
7012966
add check of main player
hiha3456 Jun 28, 2022
7912613
fix bug
hiha3456 Jun 28, 2022
2b51afc
fix bug
hiha3456 Jun 28, 2022
c912e4b
change structure of map_size from list to point
hiha3456 Jun 28, 2022
83d93d0
test(nyz): add naive distar policy collect test
PaParaZz1 Jun 28, 2022
d96f953
Merge branch 'dev-distar' into dev-distar-collector
hiha3456 Jun 29, 2022
8408399
merge dev-distar-nyz
hiha3456 Jun 29, 2022
7b5b233
to run in k8s
hiha3456 Jun 29, 2022
0b985d2
change num workers
hiha3456 Jun 29, 2022
57d3a43
to run real policy forward_collect
hiha3456 Jun 29, 2022
584fd82
print exception
hiha3456 Jun 30, 2022
df21a9a
reformat test
hiha3456 Jun 30, 2022
72803ef
fix bug
hiha3456 Jun 30, 2022
e8e7551
tools to do serialization and test if two objects same
hiha3456 Jul 6, 2022
2da4010
changes in the model to correctly make actions using pretrained model
hiha3456 Jul 6, 2022
a0efe1c
changes to run the test using pretrained model
hiha3456 Jul 6, 2022
a3c9e39
tests to test the performance againist bot using pretrained mdoel
hiha3456 Jul 6, 2022
43f86a1
changes in the policy(agent) to correctly make actions using pretrain…
hiha3456 Jul 6, 2022
a99c8bb
move GLU and build_activation in action_type_head.py to ding/torch_ut…
hiha3456 Jul 6, 2022
a11e146
change default value of build_activation to False
hiha3456 Jul 10, 2022
ca651b5
add util to change ia's model
hiha3456 Jul 10, 2022
8dd3b39
add update_fake_reward; change behaviour to behavior
hiha3456 Jul 11, 2022
70c1a7f
Merge pull request #411 from opendilab/dev-distar-collector-merge-policy
hiha3456 Jul 11, 2022
b2cfa91
add processss_transition
hiha3456 Jul 11, 2022
25a8a54
load state_dict of teacher model and other debugs
hiha3456 Jul 12, 2022
6c75891
not delete last episode when before append, the first step of newest …
hiha3456 Jul 12, 2022
1f2bc17
insert process transition into rolloutor, fix bug; move self._observa…
hiha3456 Jul 12, 2022
8dfa961
fix bug of calling hamming_distance
hiha3456 Jul 12, 2022
4a01f3d
fix bug when calling levenshtein_distance
hiha3456 Jul 12, 2022
b06d6d8
fix bug of dimension selected_units
hiha3456 Jul 13, 2022
0aedb77
changes to run the whole pipeline from sl_model
hiha3456 Jul 14, 2022
1dfda42
changes to run the winrate test
hiha3456 Jul 19, 2022
3d19436
changes to make whole pipeline running bug freely
hiha3456 Jul 19, 2022
73c0367
Merge branch 'dev-distar-fix-bug' into dev-distar-collector-merge-policy
hiha3456 Jul 19, 2022
4fc95b1
Merge pull request #420 from opendilab/dev-distar-collector-merge-policy
hiha3456 Jul 19, 2022
206b293
feature(zms):updates for distar
hiha3456 Aug 23, 2022
146643e
merge branch 'main' into dev-distar-merge-into-main
hiha3456 Aug 24, 2022
ea11bc1
change comments and delete useless code
hiha3456 Aug 24, 2022
2e60364
move out the distar files into DI-star
Aug 24, 2022
2dc4ab5
move out tensor_dict_to_shm
Aug 24, 2022
2c0e527
add comment of CpuUnpickler
Aug 24, 2022
7a3bc0b
move out distar_test_pipelines
Aug 24, 2022
3f21d3c
change import of collector.py
Aug 24, 2022
37b6bc7
drop out useless mocks
Aug 24, 2022
de8b6eb
make test of coordinator pass
Aug 25, 2022
e63d710
feature(zms): add test_league_learner_communicator.py
Aug 25, 2022
30e6249
change a bit
Aug 25, 2022
9dc8770
change file name
Aug 25, 2022
117ae79
update test_handle_step_exception.py
Aug 25, 2022
f66248d
update test of BattleTransitionList, and add last_step_fn in BattleTr…
Aug 25, 2022
8cf4a82
uupdate tests; actor, collector, functional collector
Aug 26, 2022
9af8168
remove one todo
Aug 26, 2022
6221024
reformat
Aug 26, 2022
16281d9
reformat
Aug 26, 2022
fb83c8e
fix bug
Aug 26, 2022
bc7b477
reformat; add last_step_fn in entry of LeagueActor and BattleStepColl…
Aug 26, 2022
7fc8354
update tests
Aug 26, 2022
3e01213
delete useless files
Sep 2, 2022
0a8be2e
add unittest of flatten and detach_grad
hiha3456 Sep 22, 2022
b3d39f9
add comments about the difference between GLU2 and GLU
hiha3456 Sep 22, 2022
4f09857
add unittest of parameter "dim" in default_collate; remove dim from c…
hiha3456 Sep 23, 2022
54ea00e
reformat
hiha3456 Sep 23, 2022
ef64a7e
usee pytest of test_sparse_logging
hiha3456 Sep 23, 2022
543f1ec
change format of comment of sparse_logging
hiha3456 Sep 23, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
polish stdim & infonce loss
  • Loading branch information
lixl-st committed May 30, 2022
commit 192e6e4e436ef9d67095ede4bd84fc9be58b0ebf
7 changes: 6 additions & 1 deletion ding/policy/command_mode_policy_instance.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@
from ding.rl_utils import get_epsilon_greedy_fn
from .base_policy import CommandModePolicy

from .dqn import DQNPolicy
from .dqn import DQNPolicy, DQNSTDIMPolicy
from .c51 import C51Policy
from .qrdqn import QRDQNPolicy
from .iqn import IQNPolicy
@@ -92,6 +92,11 @@ class DQNCommandModePolicy(DQNPolicy, EpsCommandModePolicy):
pass


@POLICY_REGISTRY.register('dqn_stdim_command')
class DQNSTDIMCommandModePolicy(DQNSTDIMPolicy, EpsCommandModePolicy):
pass


@POLICY_REGISTRY.register('dqfd_command')
class DQFDCommandModePolicy(DQFDPolicy, EpsCommandModePolicy):
pass
59 changes: 47 additions & 12 deletions ding/policy/dqn.py
Original file line number Diff line number Diff line change
@@ -4,13 +4,14 @@
import torch

from ding.torch_utils import Adam, to_device
from ding.torch_utils.loss.contrastive_loss import ContrastiveLoss
from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, get_nstep_return_data, get_train_sample
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate

from .base_policy import Policy
from .common_utils import default_preprocess_learn
from ding.torch_utils import ContrastiveLoss


@POLICY_REGISTRY.register('dqn')
@@ -393,7 +394,7 @@ class DQNSTDIMPolicy(DQNPolicy):
== ==================== ======== ============== ======================================== =======================
ID Symbol Type Default Value Description Other(Shape)
== ==================== ======== ============== ======================================== =======================
1 ``type`` str dqn | RL policy register name, refer to | This arg is optional,
1 ``type`` str dqn_stdim | RL policy register name, refer to | This arg is optional,
| registry ``POLICY_REGISTRY`` | a placeholder
2 ``cuda`` bool False | Whether to use cuda for network | This arg can be diff-
| erent from modes
@@ -437,14 +438,14 @@ class DQNSTDIMPolicy(DQNPolicy):
| decay from start
| value to end value
| during decay length.
20 | ``loss_ratio`` float 0.01 | the ratio of auxiliary loss to main | any real value,
| loss | typically in
20 | ``aux_loss_ratio`` float 0.05 | the ratio of the auxiliary loss to | any real value,
| the TD loss | typically in
| [-0.1, 0.1].
== ==================== ======== ============== ======================================== =======================
"""

config = dict(
type='dqn',
type='dqn_stdim',
# (bool) Whether use cuda in policy
cuda=False,
# (bool) Whether learning policy is the same as collecting data policy(on-policy)
@@ -499,18 +500,17 @@ class DQNSTDIMPolicy(DQNPolicy):
),
replay_buffer=dict(replay_buffer_size=10000, ),
),
loss_ratio=0.01,
aux_loss_ratio=0.05,
)

def _init_learn(self) -> None:
super()._init_learn()
self._main_encoder = self._model.encoder
x_size, y_size = self._get_encoding_size()
self._aux_model = ContrastiveLoss(x_size, y_size, **self._cfg.aux_model)
if self._cuda:
self._aux_model.cuda()
self._aux_optimizer = Adam(self._aux_model.parameters(), lr=self._cfg.learn.learning_rate)
self._aux_ratio = self._cfg.loss_ratio
self._aux_ratio = self._cfg.aux_loss_ratio

def _get_encoding_size(self):
obs = self._cfg.model.obs_shape
@@ -526,7 +526,7 @@ def _get_encoding_size(self):

def _aux_encode(self, data):
x = data["obs"]
y = self._main_encoder(data["obs"])
y = self._model.encoder(data["obs"])
return x, y

def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
@@ -583,14 +583,14 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], data['weight']
)
value_gamma = data.get('value_gamma')
loss, td_error_per_sample = q_nstep_td_error(data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma)
bellman_loss, td_error_per_sample = q_nstep_td_error(data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma)

# ======================
# Compute auxiliary loss
# ======================
x, y = self._aux_encode(data)
aux_loss_eval = self._aux_model.forward(x, y) * self._aux_ratio
loss += aux_loss_eval
loss = aux_loss_eval + bellman_loss

# ====================
# Q-learning update
@@ -607,10 +607,45 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
self._target_model.update(self._learn_model.state_dict())
return {
'cur_lr': self._optimizer.defaults['lr'],
'total_loss': loss.item(),
'bellman_loss': bellman_loss.item(),
'aux_loss': aux_loss_eval.item(),
'total_loss': loss.item(),
'q_value': q_value.mean().item(),
'priority': td_error_per_sample.abs().tolist(),
# Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
# '[histogram]action_distribution': data['action'],
}

def _monitor_vars_learn(self) -> List[str]:
return ['cur_lr', 'bellman_loss', 'aux_loss', 'total_loss', 'q_value']

def _state_dict_learn(self) -> Dict[str, Any]:
"""
Overview:
Return the state_dict of learn mode, usually including model and optimizer.
Returns:
- state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring.
"""
return {
'model': self._learn_model.state_dict(),
'target_model': self._target_model.state_dict(),
'optimizer': self._optimizer.state_dict(),
'aux_optimizer': self._aux_optimizer.state_dict(),
}

def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
"""
Overview:
Load the state_dict variable into policy learn mode.
Arguments:
- state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before.

.. tip::
If you want to only load some parts of model, you can simply set the ``strict`` argument in \
load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
complicated operation.
"""
self._learn_model.load_state_dict(state_dict['model'])
self._target_model.load_state_dict(state_dict['target_model'])
self._optimizer.load_state_dict(state_dict['optimizer'])
self._aux_optimizer.load_state_dict(state_dict['aux_optimizer'])
1 change: 0 additions & 1 deletion ding/torch_utils/loss/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
from .cross_entropy_loss import LabelSmoothCELoss, SoftFocalLoss, build_ce_criterion
from .multi_logits_loss import MultiLogitsLoss
from .contrastive_loss import ContrastiveLoss
9 changes: 5 additions & 4 deletions ding/torch_utils/loss/tests/test_contrastive_loss.py
Original file line number Diff line number Diff line change
@@ -2,18 +2,19 @@
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
from ding.torch_utils import ContrastiveLoss
from ding.torch_utils.loss.contrastive_loss import ContrastiveLoss


@pytest.mark.unittest
@pytest.mark.parametrize('noise', [0.1, 1.0, 3.0])
@pytest.mark.benchmark
@pytest.mark.parametrize('noise', [0.1, 1.0])
@pytest.mark.parametrize('dims', [
[16],
[3, 16, 16]
])
def test_infonce_loss(noise, dims):
print_loss = False
batch_size = 128
N_batch = 10
N_batch = 3
x_dim = [batch_size * N_batch] + dims

encode_shape = 16
5 changes: 3 additions & 2 deletions dizoo/atari/config/serial/pong/pong_dqn_stdim_config.py
Original file line number Diff line number Diff line change
@@ -24,7 +24,8 @@
loss_type = 'infonce',
temperature = 1.0,
),
loss_ratio = 0.05,
# the ratio of the auxiliary loss to the TD loss
aux_loss_ratio = 0.05,
nstep=3,
discount_factor=0.99,
learn=dict(
@@ -54,7 +55,7 @@
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='dqn'),
policy=dict(type='dqn_stdim'),
)
pong_dqn_stdim_create_config = EasyDict(pong_dqn_stdim_create_config)
create_config = pong_dqn_stdim_create_config
Original file line number Diff line number Diff line change
@@ -24,7 +24,8 @@
loss_type = 'infonce',
temperature = 1.0,
),
loss_ratio = 0.05,
# the ratio of the auxiliary loss to the TD loss
aux_loss_ratio = 0.05,
nstep=1,
discount_factor=0.97,
learn=dict(
@@ -52,7 +53,7 @@
import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
),
env_manager=dict(type='base'),
policy=dict(type='dqn'),
policy=dict(type='dqn_stdim'),
replay_buffer=dict(
type='deque',
import_names=['ding.data.buffer.deque_buffer_wrapper']