diff --git a/README.md b/README.md index 83bfa9f..faff444 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ If you are willing to use this code or cite the paper, please refer the followin $ pip install -r requirements.txt ``` -## Preprocessing +## Preprocessing ### Cityscapes We expect the original Cityscapes dataset to be located at `data/cityscapes/original`. Please refer to [Cityscapes Dataset](http://www.cityscapes-dataset.net/) and [mcordts/cityscapesScripts](https://github.com/mcordts/cityscapesScripts) for details. ```bash @@ -82,17 +82,17 @@ $ python ./scripts/preprocess_celeba.py \ ### MR-GAN ```bash -$ python main.py --train --mode mr --config ./configs/{model}-{dataset}-{distribution}-{method}.yaml --log-dir ./logs/mr +$ python main.py --mode mr --config ./configs/{model}-{dataset}-{distribution}-{method}.yaml --log-dir ./logs/mr ``` ### Proxy MR-GAN Train a predictor first and determine the checkpoint where the validation loss is minimized. ```bash -$ python main.py --train --mode pred --config configs/{model}-{dataset}-{distribution}-{method}.yaml --log-dir ./logs/predictor +$ python main.py --mode pred --config configs/{model}-{dataset}-{distribution}-{method}.yaml --log-dir ./logs/predictor ``` -Use the checkpoint as `--pred-ckpt` to train the generator. +Use the checkpoint as `--pred-ckpt` to train the generator. ```bash -$ python main.py --train --mode mr --config configs/{model}-{dataset}-{distribution}-{method}.yaml --log-dir ./logs/pmr --pred-ckpt ./logs/predictor/ckpt/{step}-p.pt +$ python main.py --mode mr --config configs/{model}-{dataset}-{distribution}-{method}.yaml --log-dir ./logs/pmr --pred-ckpt ./logs/predictor/ckpt/{step}-p.pt ``` diff --git a/configs/glcic-celeba_128x128-gaussian-pmr2.yaml b/configs/glcic-celeba_128x128-gaussian-pmr2.yaml index 4859f1d..d839ede 100644 --- a/configs/glcic-celeba_128x128-gaussian-pmr2.yaml +++ b/configs/glcic-celeba_128x128-gaussian-pmr2.yaml @@ -41,9 +41,9 @@ model: recon_weight: 100 gan_weight: 1.0 pred: - mm: - mm_1st_weight: 1000. - mm_2nd_weight: 1000. + mr: + mr_1st_weight: 1000. + mr_2nd_weight: 1000. gan_weight: 1.0 mle_weight: 0.0 @@ -52,8 +52,8 @@ g_pretrain_step: 0 d_pretrain_step: 0 batch_size: 16 -num_mm: 8 -num_mm_samples: 12 +num_mr: 8 +num_mr_samples: 12 d_updates_per_step: 1 g_updates_per_step: 3 diff --git a/configs/pix2pix-cityscapes_256x256-gaussian-mr1.yaml b/configs/pix2pix-cityscapes_256x256-gaussian-mr1.yaml index 438ef62..50fc044 100644 --- a/configs/pix2pix-cityscapes_256x256-gaussian-mr1.yaml +++ b/configs/pix2pix-cityscapes_256x256-gaussian-mr1.yaml @@ -42,16 +42,16 @@ model: recon_weight: 100. gan_weight: 1. pred: - mm: - mm_1st_weight: 0.0 - mm_2nd_weight: 0.0 + mr: + mr_1st_weight: 0.0 + mr_2nd_weight: 0.0 gan_weight: 1.0 mle_weight: 10.0 recon_weight: 0.0 # this converts optimize_g to BASE mode batch_size: 16 -num_mm: 8 -num_mm_samples: 10 +num_mr: 8 +num_mr_samples: 10 d_updates_per_step: 1 g_updates_per_step: 1 diff --git a/configs/pix2pix-cityscapes_256x256-gaussian-mr2.yaml b/configs/pix2pix-cityscapes_256x256-gaussian-mr2.yaml index a366b58..eae858b 100644 --- a/configs/pix2pix-cityscapes_256x256-gaussian-mr2.yaml +++ b/configs/pix2pix-cityscapes_256x256-gaussian-mr2.yaml @@ -42,16 +42,16 @@ model: recon_weight: 100. gan_weight: 1. pred: - mm: - mm_1st_weight: 0.0 - mm_2nd_weight: 0.0 + mr: + mr_1st_weight: 0.0 + mr_2nd_weight: 0.0 gan_weight: 1.0 mle_weight: 10.0 recon_weight: 0.0 # this converts optimize_g to BASE mode batch_size: 16 -num_mm: 8 -num_mm_samples: 10 +num_mr: 8 +num_mr_samples: 10 d_updates_per_step: 1 g_updates_per_step: 1 diff --git a/configs/pix2pix-cityscapes_256x256-gaussian-pmr1.yaml b/configs/pix2pix-cityscapes_256x256-gaussian-pmr1.yaml index c1bce05..88cdc98 100644 --- a/configs/pix2pix-cityscapes_256x256-gaussian-pmr1.yaml +++ b/configs/pix2pix-cityscapes_256x256-gaussian-pmr1.yaml @@ -42,16 +42,16 @@ model: recon_weight: 100. gan_weight: 1. pred: - mm: - mm_1st_weight: 10.0 - mm_2nd_weight: 0.0 + mr: + mr_1st_weight: 10.0 + mr_2nd_weight: 0.0 gan_weight: 1.0 mle_weight: 0.0 recon_weight: 0.0 # this converts optimize_g to BASE mode batch_size: 16 -num_mm: 8 -num_mm_samples: 10 +num_mr: 8 +num_mr_samples: 10 d_updates_per_step: 1 g_updates_per_step: 1 diff --git a/configs/pix2pix-cityscapes_256x256-gaussian-pmr2.yaml b/configs/pix2pix-cityscapes_256x256-gaussian-pmr2.yaml index 275c6db..1aa66a1 100644 --- a/configs/pix2pix-cityscapes_256x256-gaussian-pmr2.yaml +++ b/configs/pix2pix-cityscapes_256x256-gaussian-pmr2.yaml @@ -42,16 +42,16 @@ model: recon_weight: 100. gan_weight: 1. pred: - mm: - mm_1st_weight: 10.0 - mm_2nd_weight: 10.0 + mr: + mr_1st_weight: 10.0 + mr_2nd_weight: 10.0 gan_weight: 1.0 mle_weight: 0.0 recon_weight: 0.0 # this converts optimize_g to BASE mode batch_size: 16 -num_mm: 8 -num_mm_samples: 10 +num_mr: 8 +num_mr_samples: 10 d_updates_per_step: 1 g_updates_per_step: 1 diff --git a/configs/pix2pix-cityscapes_256x256-laplace-mr1.yaml b/configs/pix2pix-cityscapes_256x256-laplace-mr1.yaml index 9fdd7ed..ab17769 100644 --- a/configs/pix2pix-cityscapes_256x256-laplace-mr1.yaml +++ b/configs/pix2pix-cityscapes_256x256-laplace-mr1.yaml @@ -42,16 +42,16 @@ model: recon_weight: 100. gan_weight: 1. pred: - mm: - mm_1st_weight: 0.0 - mm_2nd_weight: 0.0 + mr: + mr_1st_weight: 0.0 + mr_2nd_weight: 0.0 gan_weight: 1.0 mle_weight: 10.0 recon_weight: 0.0 # this converts optimize_g to BASE mode batch_size: 16 -num_mm: 8 -num_mm_samples: 10 +num_mr: 8 +num_mr_samples: 10 d_updates_per_step: 1 g_updates_per_step: 1 diff --git a/configs/pix2pix-cityscapes_256x256-laplace-mr2.yaml b/configs/pix2pix-cityscapes_256x256-laplace-mr2.yaml index dd23f66..d050bcb 100644 --- a/configs/pix2pix-cityscapes_256x256-laplace-mr2.yaml +++ b/configs/pix2pix-cityscapes_256x256-laplace-mr2.yaml @@ -42,16 +42,16 @@ model: recon_weight: 100. gan_weight: 1. pred: - mm: - mm_1st_weight: 0.0 - mm_2nd_weight: 0.0 + mr: + mr_1st_weight: 0.0 + mr_2nd_weight: 0.0 gan_weight: 1.0 mle_weight: 10.0 recon_weight: 0.0 # this converts optimize_g to BASE mode batch_size: 16 -num_mm: 8 -num_mm_samples: 10 +num_mr: 8 +num_mr_samples: 10 d_updates_per_step: 1 g_updates_per_step: 1 diff --git a/configs/pix2pix-cityscapes_256x256-laplace-pmr1.yaml b/configs/pix2pix-cityscapes_256x256-laplace-pmr1.yaml index 0080ad6..274a428 100644 --- a/configs/pix2pix-cityscapes_256x256-laplace-pmr1.yaml +++ b/configs/pix2pix-cityscapes_256x256-laplace-pmr1.yaml @@ -42,16 +42,16 @@ model: recon_weight: 100. gan_weight: 1. pred: - mm: - mm_1st_weight: 1.0 - mm_2nd_weight: 0.0 + mr: + mr_1st_weight: 1.0 + mr_2nd_weight: 0.0 gan_weight: 1.0 mle_weight: 0.0 recon_weight: 0.0 # this converts optimize_g to BASE mode batch_size: 16 -num_mm: 8 -num_mm_samples: 10 +num_mr: 8 +num_mr_samples: 10 d_updates_per_step: 1 g_updates_per_step: 1 diff --git a/configs/pix2pix-cityscapes_256x256-laplace-pmr2.yaml b/configs/pix2pix-cityscapes_256x256-laplace-pmr2.yaml index 147984a..ee18c3b 100644 --- a/configs/pix2pix-cityscapes_256x256-laplace-pmr2.yaml +++ b/configs/pix2pix-cityscapes_256x256-laplace-pmr2.yaml @@ -42,16 +42,16 @@ model: recon_weight: 100. gan_weight: 1. pred: - mm: - mm_1st_weight: 1.0 - mm_2nd_weight: 1.0 + mr: + mr_1st_weight: 1.0 + mr_2nd_weight: 1.0 gan_weight: 1.0 mle_weight: 0.0 recon_weight: 0.0 # this converts optimize_g to BASE mode batch_size: 16 -num_mm: 8 -num_mm_samples: 10 +num_mr: 8 +num_mr_samples: 10 d_updates_per_step: 1 g_updates_per_step: 1 diff --git a/configs/pix2pix-maps_256x256-gaussian-mr1.yaml b/configs/pix2pix-maps_256x256-gaussian-mr1.yaml index 5dcc01b..041a695 100644 --- a/configs/pix2pix-maps_256x256-gaussian-mr1.yaml +++ b/configs/pix2pix-maps_256x256-gaussian-mr1.yaml @@ -43,16 +43,16 @@ model: recon_weight: 100. gan_weight: 1. pred: - mm: - mm_1st_weight: 0.0 - mm_2nd_weight: 0.0 + mr: + mr_1st_weight: 0.0 + mr_2nd_weight: 0.0 gan_weight: 1.0 mle_weight: 100.0 recon_weight: 0.0 # this converts optimize_g to BASE mode batch_size: 16 -num_mm: 8 -num_mm_samples: 10 +num_mr: 8 +num_mr_samples: 10 d_updates_per_step: 1 g_updates_per_step: 1 diff --git a/configs/pix2pix-maps_256x256-gaussian-mr2.yaml b/configs/pix2pix-maps_256x256-gaussian-mr2.yaml index 880d048..885c78b 100644 --- a/configs/pix2pix-maps_256x256-gaussian-mr2.yaml +++ b/configs/pix2pix-maps_256x256-gaussian-mr2.yaml @@ -43,16 +43,16 @@ model: recon_weight: 100. gan_weight: 1. pred: - mm: - mm_1st_weight: 0.0 - mm_2nd_weight: 0.0 + mr: + mr_1st_weight: 0.0 + mr_2nd_weight: 0.0 gan_weight: 1.0 mle_weight: 10.0 recon_weight: 0.0 # this converts optimize_g to BASE mode batch_size: 16 -num_mm: 8 -num_mm_samples: 10 +num_mr: 8 +num_mr_samples: 10 d_updates_per_step: 1 g_updates_per_step: 1 diff --git a/configs/pix2pix-maps_256x256-gaussian-pmr1.yaml b/configs/pix2pix-maps_256x256-gaussian-pmr1.yaml index 2f95456..e1f812b 100644 --- a/configs/pix2pix-maps_256x256-gaussian-pmr1.yaml +++ b/configs/pix2pix-maps_256x256-gaussian-pmr1.yaml @@ -43,16 +43,16 @@ model: recon_weight: 100. gan_weight: 1. pred: - mm: - mm_1st_weight: 300.0 - mm_2nd_weight: 0.0 + mr: + mr_1st_weight: 300.0 + mr_2nd_weight: 0.0 gan_weight: 1.0 mle_weight: 0.0 recon_weight: 0.0 # this converts optimize_g to BASE mode batch_size: 16 -num_mm: 8 -num_mm_samples: 10 +num_mr: 8 +num_mr_samples: 10 d_updates_per_step: 1 g_updates_per_step: 1 diff --git a/configs/pix2pix-maps_256x256-gaussian-pmr2.yaml b/configs/pix2pix-maps_256x256-gaussian-pmr2.yaml index c5500ac..7b7def3 100644 --- a/configs/pix2pix-maps_256x256-gaussian-pmr2.yaml +++ b/configs/pix2pix-maps_256x256-gaussian-pmr2.yaml @@ -43,16 +43,16 @@ model: recon_weight: 100. gan_weight: 1. pred: - mm: - mm_1st_weight: 1000.0 - mm_2nd_weight: 1000.0 + mr: + mr_1st_weight: 1000.0 + mr_2nd_weight: 1000.0 gan_weight: 1.0 mle_weight: 0.0 recon_weight: 0.0 # this converts optimize_g to BASE mode batch_size: 16 -num_mm: 8 -num_mm_samples: 10 +num_mr: 8 +num_mr_samples: 10 d_updates_per_step: 1 g_updates_per_step: 1 diff --git a/configs/pix2pix-maps_256x256-laplace-mr1.yaml b/configs/pix2pix-maps_256x256-laplace-mr1.yaml index a823017..eb7c4c1 100644 --- a/configs/pix2pix-maps_256x256-laplace-mr1.yaml +++ b/configs/pix2pix-maps_256x256-laplace-mr1.yaml @@ -43,16 +43,16 @@ model: recon_weight: 100. gan_weight: 1. pred: - mm: - mm_1st_weight: 0.0 - mm_2nd_weight: 0.0 + mr: + mr_1st_weight: 0.0 + mr_2nd_weight: 0.0 gan_weight: 1.0 mle_weight: 10.0 recon_weight: 0.0 # this converts optimize_g to BASE mode batch_size: 16 -num_mm: 8 -num_mm_samples: 10 +num_mr: 8 +num_mr_samples: 10 d_updates_per_step: 1 g_updates_per_step: 1 diff --git a/configs/pix2pix-maps_256x256-laplace-mr2.yaml b/configs/pix2pix-maps_256x256-laplace-mr2.yaml index 8d63b3b..8635de4 100644 --- a/configs/pix2pix-maps_256x256-laplace-mr2.yaml +++ b/configs/pix2pix-maps_256x256-laplace-mr2.yaml @@ -43,16 +43,16 @@ model: recon_weight: 100. gan_weight: 1. pred: - mm: - mm_1st_weight: 0.0 - mm_2nd_weight: 0.0 + mr: + mr_1st_weight: 0.0 + mr_2nd_weight: 0.0 gan_weight: 1.0 mle_weight: 1.0 recon_weight: 0.0 # this converts optimize_g to BASE mode batch_size: 16 -num_mm: 8 -num_mm_samples: 10 +num_mr: 8 +num_mr_samples: 10 d_updates_per_step: 1 g_updates_per_step: 1 diff --git a/configs/pix2pix-maps_256x256-laplace-pmr1.yaml b/configs/pix2pix-maps_256x256-laplace-pmr1.yaml index 2e7371a..1352339 100644 --- a/configs/pix2pix-maps_256x256-laplace-pmr1.yaml +++ b/configs/pix2pix-maps_256x256-laplace-pmr1.yaml @@ -43,16 +43,16 @@ model: recon_weight: 100. gan_weight: 1. pred: - mm: - mm_1st_weight: 10.0 - mm_2nd_weight: 0.0 + mr: + mr_1st_weight: 10.0 + mr_2nd_weight: 0.0 gan_weight: 1.0 mle_weight: 0.0 recon_weight: 0.0 # this converts optimize_g to BASE mode batch_size: 16 -num_mm: 8 -num_mm_samples: 10 +num_mr: 8 +num_mr_samples: 10 d_updates_per_step: 1 g_updates_per_step: 1 diff --git a/configs/pix2pix-maps_256x256-laplace-pmr2.yaml b/configs/pix2pix-maps_256x256-laplace-pmr2.yaml index b9776f3..b7da634 100644 --- a/configs/pix2pix-maps_256x256-laplace-pmr2.yaml +++ b/configs/pix2pix-maps_256x256-laplace-pmr2.yaml @@ -43,16 +43,16 @@ model: recon_weight: 100. gan_weight: 1. pred: - mm: - mm_1st_weight: 10.0 - mm_2nd_weight: 10.0 + mr: + mr_1st_weight: 10.0 + mr_2nd_weight: 10.0 gan_weight: 1.0 mle_weight: 0.0 recon_weight: 0.0 # this converts optimize_g to BASE mode batch_size: 16 -num_mm: 8 -num_mm_samples: 10 +num_mr: 8 +num_mr_samples: 10 d_updates_per_step: 1 g_updates_per_step: 1 diff --git a/configs/srgan-celeba_64x64-gaussian-mr1.yaml b/configs/srgan-celeba_64x64-gaussian-mr1.yaml index 11c0e53..2176ae9 100644 --- a/configs/srgan-celeba_64x64-gaussian-mr1.yaml +++ b/configs/srgan-celeba_64x64-gaussian-mr1.yaml @@ -27,16 +27,16 @@ model: gan_weight: 1.0 recon_weight: 1000 pred: - mm: - mm_1st_weight: 0. - mm_2nd_weight: 0. + mr: + mr_1st_weight: 0. + mr_2nd_weight: 0. gan_weight: 1.0 mle_weight: 20.0 batch_size: 32 -num_mm: 8 -num_mm_samples: 24 +num_mr: 8 +num_mr_samples: 24 d_updates_per_step: 1 g_updates_per_step: 5 diff --git a/configs/srgan-celeba_64x64-gaussian-mr2.yaml b/configs/srgan-celeba_64x64-gaussian-mr2.yaml index d407fa6..4f7efc8 100644 --- a/configs/srgan-celeba_64x64-gaussian-mr2.yaml +++ b/configs/srgan-celeba_64x64-gaussian-mr2.yaml @@ -27,16 +27,16 @@ model: gan_weight: 1.0 recon_weight: 1000 pred: - mm: - mm_1st_weight: 0. - mm_2nd_weight: 0. + mr: + mr_1st_weight: 0. + mr_2nd_weight: 0. gan_weight: 1.0 mle_weight: 20.0 batch_size: 32 -num_mm: 8 -num_mm_samples: 24 +num_mr: 8 +num_mr_samples: 24 d_updates_per_step: 1 g_updates_per_step: 5 diff --git a/configs/srgan-celeba_64x64-gaussian-pmr1.yaml b/configs/srgan-celeba_64x64-gaussian-pmr1.yaml index 3f65d60..1ec9977 100644 --- a/configs/srgan-celeba_64x64-gaussian-pmr1.yaml +++ b/configs/srgan-celeba_64x64-gaussian-pmr1.yaml @@ -27,16 +27,16 @@ model: gan_weight: 1.0 recon_weight: 1000 pred: - mm: - mm_1st_weight: 2400. - mm_2nd_weight: 0. + mr: + mr_1st_weight: 2400. + mr_2nd_weight: 0. gan_weight: 1.0 mle_weight: 0.0 batch_size: 32 -num_mm: 8 -num_mm_samples: 12 +num_mr: 8 +num_mr_samples: 12 d_updates_per_step: 1 g_updates_per_step: 5 diff --git a/configs/srgan-celeba_64x64-gaussian-pmr2.yaml b/configs/srgan-celeba_64x64-gaussian-pmr2.yaml index 2597df1..42016ab 100644 --- a/configs/srgan-celeba_64x64-gaussian-pmr2.yaml +++ b/configs/srgan-celeba_64x64-gaussian-pmr2.yaml @@ -27,16 +27,16 @@ model: gan_weight: 1.0 recon_weight: 1000 pred: - mm: - mm_1st_weight: 2400. - mm_2nd_weight: 2400. + mr: + mr_1st_weight: 2400. + mr_2nd_weight: 2400. gan_weight: 1.0 mle_weight: 0.0 batch_size: 32 -num_mm: 8 -num_mm_samples: 12 +num_mr: 8 +num_mr_samples: 12 d_updates_per_step: 1 g_updates_per_step: 5 diff --git a/configs/srgan-celeba_64x64-laplace-mr1.yaml b/configs/srgan-celeba_64x64-laplace-mr1.yaml index e71353a..6317234 100644 --- a/configs/srgan-celeba_64x64-laplace-mr1.yaml +++ b/configs/srgan-celeba_64x64-laplace-mr1.yaml @@ -27,16 +27,16 @@ model: gan_weight: 1.0 recon_weight: 1000 pred: - mm: - mm_1st_weight: 0. - mm_2nd_weight: 0. + mr: + mr_1st_weight: 0. + mr_2nd_weight: 0. gan_weight: 1.0 mle_weight: 30.0 batch_size: 32 -num_mm: 8 -num_mm_samples: 24 +num_mr: 8 +num_mr_samples: 24 d_updates_per_step: 1 g_updates_per_step: 5 diff --git a/configs/srgan-celeba_64x64-laplace-mr2.yaml b/configs/srgan-celeba_64x64-laplace-mr2.yaml index e77b358..696c0b6 100644 --- a/configs/srgan-celeba_64x64-laplace-mr2.yaml +++ b/configs/srgan-celeba_64x64-laplace-mr2.yaml @@ -27,16 +27,16 @@ model: gan_weight: 1.0 recon_weight: 1000 pred: - mm: - mm_1st_weight: 0. - mm_2nd_weight: 0. + mr: + mr_1st_weight: 0. + mr_2nd_weight: 0. gan_weight: 1.0 mle_weight: 30.0 batch_size: 32 -num_mm: 8 -num_mm_samples: 24 +num_mr: 8 +num_mr_samples: 24 d_updates_per_step: 1 g_updates_per_step: 5 diff --git a/configs/srgan-celeba_64x64-laplace-pmr1.yaml b/configs/srgan-celeba_64x64-laplace-pmr1.yaml index c018ec4..f40f67d 100644 --- a/configs/srgan-celeba_64x64-laplace-pmr1.yaml +++ b/configs/srgan-celeba_64x64-laplace-pmr1.yaml @@ -27,16 +27,16 @@ model: gan_weight: 1.0 recon_weight: 1000 pred: - mm: - mm_1st_weight: 2400. - mm_2nd_weight: 0. + mr: + mr_1st_weight: 2400. + mr_2nd_weight: 0. gan_weight: 1.0 mle_weight: 0.0 batch_size: 32 -num_mm: 8 -num_mm_samples: 12 +num_mr: 8 +num_mr_samples: 12 d_updates_per_step: 1 g_updates_per_step: 5 diff --git a/configs/srgan-celeba_64x64-laplace-pmr2.yaml b/configs/srgan-celeba_64x64-laplace-pmr2.yaml index 7f15d63..3889dfd 100644 --- a/configs/srgan-celeba_64x64-laplace-pmr2.yaml +++ b/configs/srgan-celeba_64x64-laplace-pmr2.yaml @@ -27,16 +27,16 @@ model: gan_weight: 1.0 recon_weight: 1000 pred: - mm: - mm_1st_weight: 2400. - mm_2nd_weight: 2400. + mr: + mr_1st_weight: 2400. + mr_2nd_weight: 2400. gan_weight: 1.0 mle_weight: 0.0 batch_size: 32 -num_mm: 8 -num_mm_samples: 12 +num_mr: 8 +num_mr_samples: 12 d_updates_per_step: 1 g_updates_per_step: 5 diff --git a/main.py b/main.py index 88fcdaf..8b432fe 100644 --- a/main.py +++ b/main.py @@ -21,7 +21,7 @@ assert args.mode in MODES, 'Unknown mode %s' % args.mode if args.mode == MODE_MR: if not (args.pred_ckpt or args.resume_ckpt): - print('WARNING: Moment matching mode requires ' + print('WARNING: Proxy MR-GAN requires ' 'checkpoint path of a predictor') # Load config diff --git a/models/base.py b/models/base.py index 00a7ad6..c05635f 100644 --- a/models/base.py +++ b/models/base.py @@ -247,12 +247,12 @@ def sample_statistics(self, samples): if len(samples.size()) == 4: samples = samples.unsqueeze(0) if isinstance(self.mle_loss, GaussianMLELoss): - num_samples = samples.size(1) + num_mr_samples = samples.size(1) sample_1st = samples.mean(dim=1, keepdim=True) # Tensor.std has bug up to PyTorch 0.4.1 sample_2nd = (samples - sample_1st) ** 2 sample_2nd = sample_2nd.sum(dim=1, keepdim=True) - sample_2nd /= num_samples - 1 + sample_2nd /= num_mr_samples - 1 # Laplace statistics elif isinstance(self.mle_loss, LaplaceMLELoss): @@ -273,21 +273,21 @@ def build_d_input(self, x, samples): """Build discriminator input""" raise NotImplementedError - def accumulate_mm_grad(self, x, y, summarize=False): + def accumulate_mr_grad(self, x, y, summarize=False): # Initialize summaries scalar = {} histogram = {} image = {} - num_mm = self.config.num_mm - num_samples = self.config.num_mm_samples + num_mr = self.config.num_mr + num_mr_samples = self.config.num_mr_samples loss = 0. # Get predictive mean and variance if self.net_p is not None: with torch.no_grad(): - pred_1st, pred_log_2nd = self.net_p(x[:num_mm]) + pred_1st, pred_log_2nd = self.net_p(x[:num_mr]) pred_2nd = torch.exp(pred_log_2nd) if summarize: @@ -297,7 +297,7 @@ def accumulate_mm_grad(self, x, y, summarize=False): pred_1st, pred_log_2nd, pred_2nd = None, None, None # Get samples - samples, _ = self.net_g(x[:num_mm], num_samples=num_samples) + samples, _ = self.net_g(x[:num_mr], num_samples=num_mr_samples) # GAN loss if self.loss_config.gan_weight > 0: @@ -311,7 +311,7 @@ def accumulate_mm_grad(self, x, y, summarize=False): # Get sample mean and variance samples = samples.view( - num_mm, num_samples, *list(samples.size()[1:])) + num_mr, num_mr_samples, *list(samples.size()[1:])) sample_1st, sample_2nd = self.sample_statistics(samples) if summarize: @@ -325,22 +325,22 @@ def accumulate_mm_grad(self, x, y, summarize=False): # Direct MLE without predictor if self.loss_config.mle_weight > 0: if self.name == 'glcic': - masks = x[:num_mm, -1:, ...] + masks = x[:num_mr, -1:, ...] sample_2nd = ( sample_2nd * masks + math.exp(self.config.log_dispersion_min) * (1. - masks) ) - normalizer = x[:num_mm, -1:, ...].sum() + normalizer = x[:num_mr, -1:, ...].sum() else: normalizer = None if isinstance(self.mle_loss, GaussianMLELoss): mle_loss = self.mle_loss( - sample_1st, sample_2nd, self.mle_target(y[:num_mm]), + sample_1st, sample_2nd, self.mle_target(y[:num_mr]), log_dispersion=False, normalizer=normalizer) elif isinstance(self.mle_loss, LaplaceMLELoss): sample_mean = samples.mean(1) with torch.no_grad(): - deviation = self.mle_target(y[:num_mm]) - sample_1st + deviation = self.mle_target(y[:num_mr]) - sample_1st mean_target = (deviation + sample_mean).detach() mle_loss = self.mle_loss( sample_mean, sample_2nd, mean_target, @@ -355,35 +355,35 @@ def accumulate_mm_grad(self, x, y, summarize=False): scalar['loss/g/mle'] = mle_loss.detach() # Moment matching - if self.loss_config.mm_1st_weight or self.loss_config.mm_2nd_weight: + if self.loss_config.mr_1st_weight or self.loss_config.mr_2nd_weight: normalizer = ( - x[:num_mm, -1:, ...].sum() if self.name == 'glcic' else None + x[:num_mr, -1:, ...].sum() if self.name == 'glcic' else None ) if isinstance(self.mle_loss, GaussianMLELoss): - mm_1st_loss = self._mse_loss( + mr_1st_loss = self._mse_loss( sample_1st, pred_1st, normalizer=normalizer ) elif isinstance(self.mle_loss, LaplaceMLELoss): sample_mean = samples.mean(1) with torch.no_grad(): mean_target = (pred_1st - sample_1st + sample_mean).detach() - mm_1st_loss = self._mse_loss( + mr_1st_loss = self._mse_loss( sample_mean, mean_target, normalizer=normalizer ) else: raise RuntimeError('Invalid MLE loss') - mm_2nd_loss = self._mse_loss( + mr_2nd_loss = self._mse_loss( sample_2nd, pred_2nd, normalizer=normalizer ) - weighted_mm_loss = \ - self.loss_config.mm_1st_weight * mm_1st_loss + \ - self.loss_config.mm_2nd_weight * mm_2nd_loss - loss += weighted_mm_loss + weighted_mr_loss = \ + self.loss_config.mr_1st_weight * mr_1st_loss + \ + self.loss_config.mr_2nd_weight * mr_2nd_loss + loss += weighted_mr_loss if summarize: - scalar['loss/g/mm_1st'] = mm_1st_loss.detach() - scalar['loss/g/mm_2nd'] = mm_2nd_loss.detach() - scalar['loss/g/mm'] = weighted_mm_loss.detach() + scalar['loss/g/mr_1st'] = mr_1st_loss.detach() + scalar['loss/g/mr_2nd'] = mr_2nd_loss.detach() + scalar['loss/g/mr'] = weighted_mr_loss.detach() if summarize: scalar['loss/g/total'] = loss.detach() diff --git a/models/glcic.py b/models/glcic.py index 37e7e25..8704e94 100644 --- a/models/glcic.py +++ b/models/glcic.py @@ -386,30 +386,30 @@ def optimize_g(self, x, y, step, summarize=False): if summarize: scalar['loss/g/mse'] = mse_loss.detach() scalar['loss/g/total'] += weighted_mse_loss.detach() - # backprop before accumulating mm gradients + # backprop before accumulating mr gradients if isinstance(loss, torch.Tensor): loss.backward() # MR loss (MR) if self.mode == MODE_MR: - mm_summaries = self.accumulate_mm_grad(x, y, summarize) - mm_scalar = mm_summaries['scalar'] - mm_histogram = mm_summaries['histogram'] - mm_image = mm_summaries['image'] + mr_summaries = self.accumulate_mr_grad(x, y, summarize) + mr_scalar = mr_summaries['scalar'] + mr_histogram = mr_summaries['histogram'] + mr_image = mr_summaries['image'] torch.cuda.empty_cache() if summarize: - scalar['loss/g/total'] += mm_scalar['loss/g/total'] - del mm_scalar['loss/g/total'] - scalar.update(mm_scalar) - histogram.update(mm_histogram) - for i in range(min(16, self.config.num_mm)): + scalar['loss/g/total'] += mr_scalar['loss/g/total'] + del mr_scalar['loss/g/total'] + scalar.update(mr_scalar) + histogram.update(mr_histogram) + for i in range(min(16, self.config.num_mr)): image_id = 'train_samples/%d' % i if image_id in image: image[image_id] = torch.cat([ - image[image_id], mm_image[image_id] + image[image_id], mr_image[image_id] ], 2) else: - image[image_id] = mm_image[image_id] + image[image_id] = mr_image[image_id] image[image_id] = image[image_id] # Optimize the network self.clip_grad(self.optim_g, self.config.g_optimizer.clip_grad) @@ -558,17 +558,17 @@ def mle_target(self, y): return y def build_d_input(self, x, samples): - num_mm = self.config.num_mm - num_samples = self.config.num_mm_samples - local_box_dup = x.local_boxes[:num_mm].unsqueeze(1) + num_mr = self.config.num_mr + num_samples = self.config.num_mr_samples + local_box_dup = x.local_boxes[:num_mr].unsqueeze(1) local_box_dup = local_box_dup.expand(-1, num_samples, -1) local_box_dup = local_box_dup.contiguous().view( - num_mm * num_samples, local_box_dup.size(2) + num_mr * num_samples, local_box_dup.size(2) ) - x_dup = x[:num_mm].unsqueeze(1).expand(-1, num_samples, -1, -1, -1) + x_dup = x[:num_mr].unsqueeze(1).expand(-1, num_samples, -1, -1, -1) x_dup = x_dup.contiguous().view( - num_mm * num_samples, *list(x_dup.size()[2:])) + num_mr * num_samples, *list(x_dup.size()[2:])) x_dup.local_boxes = local_box_dup diff --git a/models/pix2pix.py b/models/pix2pix.py index 68795e6..a95c5a2 100644 --- a/models/pix2pix.py +++ b/models/pix2pix.py @@ -84,13 +84,13 @@ def mle_target(self, y): return y def build_d_input(self, x, samples): - num_mm = self.config.num_mm - num_samples = self.config.num_mm_samples + num_mr = self.config.num_mr + num_samples = self.config.num_mr_samples - x_dup = x[:num_mm].unsqueeze(1) + x_dup = x[:num_mr].unsqueeze(1) x_dup = x_dup.expand(-1, num_samples, -1, -1, -1) x_dup = x_dup.contiguous().view( - num_mm * num_samples, *list(x_dup.size()[2:])) + num_mr * num_samples, *list(x_dup.size()[2:])) return torch.cat([x_dup, samples], 1) def optimize_g(self, x, y, step, summarize=False): @@ -128,29 +128,29 @@ def optimize_g(self, x, y, step, summarize=False): # Moment matching loss if self.mode == MODE_MR and ( - self.loss_config.mm_1st_weight > 0 - or self.loss_config.mm_2nd_weight > 0 + self.loss_config.mr_1st_weight > 0 + or self.loss_config.mr_2nd_weight > 0 or self.loss_config.mle_weight > 0 ): - mm_summaries = self.accumulate_mm_grad(x, y, summarize) - mm_scalar = mm_summaries['scalar'] - mm_histogram = mm_summaries['histogram'] - mm_image = mm_summaries['image'] + mr_summaries = self.accumulate_mr_grad(x, y, summarize) + mr_scalar = mr_summaries['scalar'] + mr_histogram = mr_summaries['histogram'] + mr_image = mr_summaries['image'] torch.cuda.empty_cache() if summarize: - scalar['loss/g/total'] += mm_scalar['loss/g/total'] - del mm_scalar['loss/g/total'] - scalar.update(mm_scalar) - histogram.update(mm_histogram) - for i in range(min(16, self.config.num_mm)): + scalar['loss/g/total'] += mr_scalar['loss/g/total'] + del mr_scalar['loss/g/total'] + scalar.update(mr_scalar) + histogram.update(mr_histogram) + for i in range(min(16, self.config.num_mr)): image_id = 'train_samples/%d' % i if image_id in image: image[image_id] = torch.cat([ - image[image_id], mm_image[image_id] + image[image_id], mr_image[image_id] ], 2) else: - image[image_id] = mm_image[image_id] + image[image_id] = mr_image[image_id] image[image_id] = image[image_id] * 0.5 + 0.5 # Optimize @@ -382,7 +382,7 @@ def __init__(self, mode, config): # Create up conv in reverse order prev_ch = self.out_ch next_ch = self.num_features - self.mm_ch, self.up_ch = [], [] + self.mr_ch, self.up_ch = [], [] for i in range(self.num_downs): if self.mode == MODE_MR: noise_dim = self.noise_dim[i + 1] diff --git a/models/srgan.py b/models/srgan.py index be08559..a86f0c5 100644 --- a/models/srgan.py +++ b/models/srgan.py @@ -263,29 +263,29 @@ def optimize_g(self, x, y, step, summarize=False): if summarize: scalar['loss/g/mse'] = mse_loss.detach() scalar['loss/g/total'] += weighted_mse_loss.detach() - # backprop before accumulating mm gradients + # backprop before accumulating mr gradients if isinstance(loss, torch.Tensor): loss.backward() # MR loss (MR) if self.mode == MODE_MR: - mm_summaries = self.accumulate_mm_grad(x, y, summarize) - mm_scalar = mm_summaries['scalar'] - mm_histogram = mm_summaries['histogram'] - mm_image = mm_summaries['image'] + mr_summaries = self.accumulate_mr_grad(x, y, summarize) + mr_scalar = mr_summaries['scalar'] + mr_histogram = mr_summaries['histogram'] + mr_image = mr_summaries['image'] torch.cuda.empty_cache() if summarize: - scalar['loss/g/total'] += mm_scalar['loss/g/total'] - del mm_scalar['loss/g/total'] - scalar.update(mm_scalar) - histogram.update(mm_histogram) - for i in range(min(16, self.config.num_mm)): + scalar['loss/g/total'] += mr_scalar['loss/g/total'] + del mr_scalar['loss/g/total'] + scalar.update(mr_scalar) + histogram.update(mr_histogram) + for i in range(min(16, self.config.num_mr)): image_id = 'train_samples/%d' % i if image_id in image: image[image_id] = torch.cat([ - image[image_id], mm_image[image_id] + image[image_id], mr_image[image_id] ], 2) else: - image[image_id] = mm_image[image_id] + image[image_id] = mr_image[image_id] image[image_id] = (image[image_id] + 1) / 2 # Optimize the network self.clip_grad(self.optim_g, self.config.g_optimizer.clip_grad)