diff --git a/.github/dot.png b/.github/dot.png
new file mode 100644
index 0000000..4ea6ab5
Binary files /dev/null and b/.github/dot.png differ
diff --git a/README.md b/README.md
index 88279ae..513462a 100755
--- a/README.md
+++ b/README.md
@@ -6,6 +6,37 @@ This repo is
(2) the official implementation of the CVPR-2022 paper: [Decoupled Knowledge Distillation](https://arxiv.org/abs/2203.08679).
+(3) the official implementation of the ICCV-2023 paper: [DOT: A Distillation-Oriented Trainer](https://openaccess.thecvf.com/content/ICCV2023/papers/Zhao_DOT_A_Distillation-Oriented_Trainer_ICCV_2023_paper.pdf).
+
+
+# DOT: A Distillation-Oriented Trainer
+
+### Framework
+
+
+
+### Main Benchmark Results
+
+On CIFAR-100:
+
+| Teacher <br> Student | ResNet32x4 <br> ResNet8x4| VGG13 <br> VGG8| ResNet32x4 <br> ShuffleNet-V2|
+|:---------------:|:-----------------:|:-----------------:|:-----------------:|
+| KD | 73.33 | 72.98 | 74.45 |
+| **KD+DOT** | **75.12** | **73.77** | **75.55** |
+
+On Tiny-ImageNet:
+
+| Teacher <br> Student |ResNet18 <br> MobileNet-V2|ResNet18 <br> ShuffleNet-V2|
+|:---------------:|:-----------------:|:-----------------:|
+| KD | 58.35 | 62.26 |
+| **KD+DOT** | **64.01** | **65.75** |
+
+On ImageNet:
+
+| Teacher <br> Student |ResNet34 <br> ResNet18|ResNet50 <br> MobileNet-V1|
+|:---------------:|:-----------------:|:-----------------:|
+| KD | 71.03 | 70.50 |
+| **KD+DOT** | **71.72** | **73.09** |
# Decoupled Knowledge Distillation
@@ -170,6 +201,12 @@ If this repo is helpful for your research, please consider citing the paper:
journal={arXiv preprint arXiv:2203.08679},
year={2022}
}
+@article{zhao2023dot,
+ title={DOT: A Distillation-Oriented Trainer},
+ author={Zhao, Borui and Cui, Quan and Song, Renjie and Liang, Jiajun},
+ journal={arXiv preprint arXiv:2307.08436},
+ year={2023}
+}
```
# License
diff --git a/configs/cifar100/dot/res32x4_res8x4.yaml b/configs/cifar100/dot/res32x4_res8x4.yaml
new file mode 100644
index 0000000..81b406d
--- /dev/null
+++ b/configs/cifar100/dot/res32x4_res8x4.yaml
@@ -0,0 +1,20 @@
+EXPERIMENT:
+ NAME: ""
+ TAG: "kd,dot,res32x4,res8x4"
+ PROJECT: "dot_cifar"
+DISTILLER:
+ TYPE: "KD"
+ TEACHER: "resnet32x4"
+ STUDENT: "resnet8x4"
+SOLVER:
+ BATCH_SIZE: 64
+ EPOCHS: 240
+ LR: 0.05
+ LR_DECAY_STAGES: [150, 180, 210]
+ LR_DECAY_RATE: 0.1
+ WEIGHT_DECAY: 0.0005
+ MOMENTUM: 0.9
+ TYPE: "SGD"
+ TRAINER: "dot"
+ DOT:
+ DELTA: 0.075
diff --git a/configs/cifar100/dot/res32x4_shuv2.yaml b/configs/cifar100/dot/res32x4_shuv2.yaml
new file mode 100755
index 0000000..9072871
--- /dev/null
+++ b/configs/cifar100/dot/res32x4_shuv2.yaml
@@ -0,0 +1,20 @@
+EXPERIMENT:
+ NAME: ""
+ TAG: "kd,dot,res32x4,shuv2"
+ PROJECT: "dot_cifar"
+DISTILLER:
+ TYPE: "KD"
+ TEACHER: "resnet32x4"
+ STUDENT: "ShuffleV2"
+SOLVER:
+ BATCH_SIZE: 64
+ EPOCHS: 240
+ LR: 0.01
+ LR_DECAY_STAGES: [150, 180, 210]
+ LR_DECAY_RATE: 0.1
+ WEIGHT_DECAY: 0.0005
+ MOMENTUM: 0.9
+ TYPE: "SGD"
+ TRAINER: "dot"
+ DOT:
+ DELTA: 0.075
diff --git a/configs/cifar100/dot/vgg13_vgg8.yaml b/configs/cifar100/dot/vgg13_vgg8.yaml
new file mode 100755
index 0000000..9a684ef
--- /dev/null
+++ b/configs/cifar100/dot/vgg13_vgg8.yaml
@@ -0,0 +1,20 @@
+EXPERIMENT:
+ NAME: ""
+ TAG: "kd,dot,vgg13,vgg8"
+ PROJECT: "dot_cifar"
+DISTILLER:
+ TYPE: "KD"
+ TEACHER: "vgg13"
+ STUDENT: "vgg8"
+SOLVER:
+ BATCH_SIZE: 64
+ EPOCHS: 240
+ LR: 0.05
+ LR_DECAY_STAGES: [150, 180, 210]
+ LR_DECAY_RATE: 0.1
+ WEIGHT_DECAY: 0.0005
+ MOMENTUM: 0.9
+ TYPE: "SGD"
+ TRAINER: "dot"
+ DOT:
+ DELTA: 0.075
diff --git a/configs/imagenet/r34_r18/dot.yaml b/configs/imagenet/r34_r18/dot.yaml
new file mode 100755
index 0000000..8d5e7b3
--- /dev/null
+++ b/configs/imagenet/r34_r18/dot.yaml
@@ -0,0 +1,33 @@
+EXPERIMENT:
+ NAME: ""
+ TAG: "kd,dot,res34,res18"
+ PROJECT: "dot_imagenet"
+DATASET:
+ TYPE: "imagenet"
+ NUM_WORKERS: 32
+ TEST:
+ BATCH_SIZE: 128
+DISTILLER:
+ TYPE: "KD"
+ TEACHER: "ResNet34"
+ STUDENT: "ResNet18"
+SOLVER:
+ BATCH_SIZE: 512
+ EPOCHS: 100
+ LR: 0.2
+ LR_DECAY_STAGES: [30, 60, 90]
+ LR_DECAY_RATE: 0.1
+ WEIGHT_DECAY: 0.0001
+ MOMENTUM: 0.9
+ TYPE: "SGD"
+ TRAINER: "dot"
+ DOT:
+ DELTA: 0.09
+KD:
+ TEMPERATURE: 1
+ LOSS:
+ CE_WEIGHT: 0.5
+ KD_WEIGHT: 0.5
+LOG:
+ TENSORBOARD_FREQ: 50
+ SAVE_CHECKPOINT_FREQ: 10
diff --git a/configs/imagenet/r50_mv1/dot.yaml b/configs/imagenet/r50_mv1/dot.yaml
new file mode 100755
index 0000000..a8bd6c7
--- /dev/null
+++ b/configs/imagenet/r50_mv1/dot.yaml
@@ -0,0 +1,33 @@
+EXPERIMENT:
+ NAME: ""
+ TAG: "kd,dot,res50,mobilenetv1"
+ PROJECT: "dot_imagenet"
+DATASET:
+ TYPE: "imagenet"
+ NUM_WORKERS: 32
+ TEST:
+ BATCH_SIZE: 128
+DISTILLER:
+ TYPE: "KD"
+ TEACHER: "ResNet50"
+ STUDENT: "MobileNetV1"
+SOLVER:
+ BATCH_SIZE: 512
+ EPOCHS: 100
+ LR: 0.2
+ LR_DECAY_STAGES: [30, 60, 90]
+ LR_DECAY_RATE: 0.1
+ WEIGHT_DECAY: 0.0001
+ MOMENTUM: 0.9
+ TYPE: "SGD"
+ TRAINER: "dot"
+ DOT:
+ DELTA: 0.09
+KD:
+ TEMPERATURE: 1
+ LOSS:
+ CE_WEIGHT: 0.5
+ KD_WEIGHT: 0.5
+LOG:
+ TENSORBOARD_FREQ: 50
+ SAVE_CHECKPOINT_FREQ: 10
diff --git a/configs/tiny_imagenet/dot/r18_mv2.yaml b/configs/tiny_imagenet/dot/r18_mv2.yaml
new file mode 100644
index 0000000..02bebae
--- /dev/null
+++ b/configs/tiny_imagenet/dot/r18_mv2.yaml
@@ -0,0 +1,23 @@
+EXPERIMENT:
+ NAME: ""
+ TAG: "kd,dot,r18,mv2"
+ PROJECT: "dot_tinyimagenet"
+DATASET:
+ TYPE: "tiny_imagenet"
+ NUM_WORKERS: 16
+DISTILLER:
+ TYPE: "KD"
+ TEACHER: "ResNet18"
+ STUDENT: "MobileNetV2"
+SOLVER:
+ BATCH_SIZE: 256
+ EPOCHS: 200
+ LR: 0.2
+ LR_DECAY_STAGES: [60, 120, 160]
+ LR_DECAY_RATE: 0.1
+ WEIGHT_DECAY: 0.0005
+ MOMENTUM: 0.9
+ TYPE: "SGD"
+ TRAINER: "dot"
+ DOT:
+ DELTA: 0.075
diff --git a/configs/tiny_imagenet/dot/r18_shuv2.yaml b/configs/tiny_imagenet/dot/r18_shuv2.yaml
new file mode 100644
index 0000000..972da3d
--- /dev/null
+++ b/configs/tiny_imagenet/dot/r18_shuv2.yaml
@@ -0,0 +1,23 @@
+EXPERIMENT:
+ NAME: ""
+ TAG: "kd,dot,r18,shuv2"
+ PROJECT: "dot_tinyimagenet"
+DATASET:
+ TYPE: "tiny_imagenet"
+ NUM_WORKERS: 16
+DISTILLER:
+ TYPE: "KD"
+ TEACHER: "ResNet18"
+ STUDENT: "ShuffleV2"
+SOLVER:
+ BATCH_SIZE: 256
+ EPOCHS: 200
+ LR: 0.2
+ LR_DECAY_STAGES: [60, 120, 160]
+ LR_DECAY_RATE: 0.1
+ WEIGHT_DECAY: 0.0005
+ MOMENTUM: 0.9
+ TYPE: "SGD"
+ TRAINER: "dot"
+ DOT:
+ DELTA: 0.075
diff --git a/mdistiller/dataset/__init__.py b/mdistiller/dataset/__init__.py
index f612437..afaea69 100755
--- a/mdistiller/dataset/__init__.py
+++ b/mdistiller/dataset/__init__.py
@@ -1,5 +1,6 @@
from .cifar100 import get_cifar100_dataloaders, get_cifar100_dataloaders_sample
from .imagenet import get_imagenet_dataloaders, get_imagenet_dataloaders_sample
+from .tiny_imagenet import get_tinyimagenet_dataloader, get_tinyimagenet_dataloader_sample
def get_dataset(cfg):
@@ -34,6 +35,21 @@ def get_dataset(cfg):
num_workers=cfg.DATASET.NUM_WORKERS,
)
num_classes = 1000
+ elif cfg.DATASET.TYPE == "tiny_imagenet":
+ if cfg.DISTILLER.TYPE in ("CRD", "CRDKD"):
+ train_loader, val_loader, num_data = get_tinyimagenet_dataloader_sample(
+ batch_size=cfg.SOLVER.BATCH_SIZE,
+ val_batch_size=cfg.DATASET.TEST.BATCH_SIZE,
+ num_workers=cfg.DATASET.NUM_WORKERS,
+ k=cfg.CRD.NCE.K,
+ )
+ else:
+ train_loader, val_loader, num_data = get_tinyimagenet_dataloader(
+ batch_size=cfg.SOLVER.BATCH_SIZE,
+ val_batch_size=cfg.DATASET.TEST.BATCH_SIZE,
+ num_workers=cfg.DATASET.NUM_WORKERS,
+ )
+ num_classes = 200
else:
raise NotImplementedError(cfg.DATASET.TYPE)
diff --git a/mdistiller/dataset/tiny_imagenet.py b/mdistiller/dataset/tiny_imagenet.py
new file mode 100644
index 0000000..e1b6e7b
--- /dev/null
+++ b/mdistiller/dataset/tiny_imagenet.py
@@ -0,0 +1,122 @@
+import os
+from torch.utils.data import DataLoader
+from torchvision import datasets
+from torchvision import transforms
+import numpy as np
+
+
+data_folder = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), "../../data/tiny-imagenet-200"
+)
+
+
+class ImageFolderInstance(datasets.ImageFolder):
+ def __getitem__(self, index):
+ path, target = self.imgs[index]
+ img = self.loader(path)
+ if self.transform is not None:
+ img = self.transform(img)
+ return img, target, index
+
+
+class ImageFolderInstanceSample(ImageFolderInstance):
+    """Folder dataset that returns (img, label, index, contrast_index).
+ """
+ def __init__(self, folder, transform=None, target_transform=None,
+ is_sample=False, k=4096):
+ super().__init__(folder, transform=transform)
+
+ self.k = k
+ self.is_sample = is_sample
+ if self.is_sample:
+ num_classes = 200
+ num_samples = len(self.samples)
+ label = np.zeros(num_samples, dtype=np.int32)
+ for i in range(num_samples):
+ img, target = self.samples[i]
+ label[i] = target
+
+ self.cls_positive = [[] for i in range(num_classes)]
+ for i in range(num_samples):
+ self.cls_positive[label[i]].append(i)
+
+ self.cls_negative = [[] for i in range(num_classes)]
+ for i in range(num_classes):
+ for j in range(num_classes):
+ if j == i:
+ continue
+ self.cls_negative[i].extend(self.cls_positive[j])
+
+ self.cls_positive = [np.asarray(self.cls_positive[i], dtype=np.int32) for i in range(num_classes)]
+ self.cls_negative = [np.asarray(self.cls_negative[i], dtype=np.int32) for i in range(num_classes)]
+ print('dataset initialized!')
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+ Returns:
+ tuple: (image, target) where target is class_index of the target class.
+ """
+ img, target, index = super().__getitem__(index)
+
+ if self.is_sample:
+ # sample contrastive examples
+ pos_idx = index
+ neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=True)
+ sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
+ return img, target, index, sample_idx
+ else:
+ return img, target, index
+
+
+def get_tinyimagenet_dataloader(batch_size, val_batch_size, num_workers):
+ """Data Loader for tiny-imagenet"""
+ train_transform = transforms.Compose([
+ transforms.RandomRotation(20),
+ transforms.RandomHorizontalFlip(0.5),
+ transforms.ToTensor(),
+ transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
+ ])
+ test_transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
+ ])
+ train_folder = os.path.join(data_folder, "train")
+ test_folder = os.path.join(data_folder, "val")
+ train_set = ImageFolderInstance(train_folder, transform=train_transform)
+ num_data = len(train_set)
+ test_set = datasets.ImageFolder(test_folder, transform=test_transform)
+ train_loader = DataLoader(
+ train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers
+ )
+ test_loader = DataLoader(
+ test_set, batch_size=val_batch_size, shuffle=False, num_workers=1
+ )
+ return train_loader, test_loader, num_data
+
+
+def get_tinyimagenet_dataloader_sample(batch_size, val_batch_size, num_workers, k):
+ """Data Loader for tiny-imagenet"""
+ train_transform = transforms.Compose([
+ transforms.RandomRotation(20),
+ transforms.RandomHorizontalFlip(0.5),
+ transforms.ToTensor(),
+ transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
+ ])
+ test_transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
+ ])
+ train_folder = os.path.join(data_folder, "train")
+ test_folder = os.path.join(data_folder, "val")
+ train_set = ImageFolderInstanceSample(train_folder, transform=train_transform, is_sample=True, k=k)
+ num_data = len(train_set)
+ test_set = datasets.ImageFolder(test_folder, transform=test_transform)
+ train_loader = DataLoader(
+ train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers
+ )
+ test_loader = DataLoader(
+ test_set, batch_size=val_batch_size, shuffle=False, num_workers=1
+ )
+ return train_loader, test_loader, num_data
diff --git a/mdistiller/engine/__init__.py b/mdistiller/engine/__init__.py
index d931376..4883802 100755
--- a/mdistiller/engine/__init__.py
+++ b/mdistiller/engine/__init__.py
@@ -1,6 +1,7 @@
-from .trainer import BaseTrainer, CRDTrainer
-
+from .trainer import BaseTrainer, CRDTrainer, DOT, CRDDOT
trainer_dict = {
"base": BaseTrainer,
"crd": CRDTrainer,
+ "dot": DOT,
+ "crd_dot": CRDDOT,
}
diff --git a/mdistiller/engine/cfg.py b/mdistiller/engine/cfg.py
index e8179fa..41625f0 100755
--- a/mdistiller/engine/cfg.py
+++ b/mdistiller/engine/cfg.py
@@ -165,3 +165,8 @@ def show_cfg(cfg):
CFG.DKD.BETA = 8.0
CFG.DKD.T = 4.0
CFG.DKD.WARMUP = 20
+
+
+# DOT CFG
+CFG.SOLVER.DOT = CN()
+CFG.SOLVER.DOT.DELTA = 0.075
diff --git a/mdistiller/engine/dot.py b/mdistiller/engine/dot.py
new file mode 100644
index 0000000..1619781
--- /dev/null
+++ b/mdistiller/engine/dot.py
@@ -0,0 +1,174 @@
+import math
+import torch
+from torch import Tensor
+import torch.optim._functional as F
+from torch.optim.optimizer import Optimizer, required
+from typing import List, Optional
+
+
+def check_in(t, l):
+ for i in l:
+ if t is i:
+ return True
+ return False
+
+def dot(params: List[Tensor],
+ d_p_list: List[Tensor],
+ momentum_buffer_list: List[Optional[Tensor]],
+ kd_grad_buffer: List[Optional[Tensor]],
+ kd_momentum_buffer: List[Optional[Tensor]],
+ kd_params: List[Tensor],
+ *,
+ weight_decay: float,
+ momentum: float,
+ momentum_kd: float,
+ lr: float,
+ dampening: float):
+ for i, param in enumerate(params):
+ d_p = d_p_list[i]
+ if weight_decay != 0:
+ d_p = d_p.add(param, alpha=weight_decay)
+ if momentum != 0:
+ buf = momentum_buffer_list[i]
+ if buf is None:
+ buf = torch.clone(d_p).detach()
+ momentum_buffer_list[i] = buf
+ elif check_in(param, kd_params):
+ buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
+ else:
+ buf.mul_((momentum_kd + momentum) / 2.).add_(d_p, alpha=1 - dampening)
+ d_p = buf
+ # update params with task-loss grad
+ param.add_(d_p, alpha=-lr)
+
+ for i, (d_p, buf, p) in enumerate(zip(kd_grad_buffer, kd_momentum_buffer, kd_params)):
+ # update params with kd-loss grad
+ if buf is None:
+ buf = torch.clone(d_p).detach()
+ kd_momentum_buffer[i] = buf
+ elif check_in(p, params):
+ buf.mul_(momentum_kd).add_(d_p, alpha=1 - dampening)
+ else:
+ if weight_decay != 0:
+ d_p = d_p.add(p, alpha=weight_decay)
+ buf.mul_((momentum_kd + momentum) / 2.).add_(d_p, alpha=1 - dampening)
+ p.add_(buf, alpha=-lr)
+
+
+class DistillationOrientedTrainer(Optimizer):
+ r"""
+ Distillation-Oriented Trainer
+ Usage:
+ ...
+ optimizer = DistillationOrientedTrainer()
+ optimizer.zero_grad(set_to_none=True)
+ loss_kd.backward(retain_graph=True)
+ optimizer.step_kd() # get kd-grad and update kd-momentum
+ optimizer.zero_grad(set_to_none=True)
+ loss_task.backward()
+    optimizer.step() # get task-grad and update task-momentum, then update params.
+ ...
+ """
+
+ def __init__(
+ self,
+ params,
+ lr=required,
+ momentum=0,
+ momentum_kd=0,
+ dampening=0,
+ weight_decay=0):
+ if lr is not required and lr < 0.0:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if momentum < 0.0:
+ raise ValueError("Invalid momentum value: {}".format(momentum))
+ if momentum_kd < 0.0:
+ raise ValueError("Invalid momentum kd value: {}".format(momentum_kd))
+ if weight_decay < 0.0:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ defaults = dict(lr=lr, momentum=momentum, momentum_kd=momentum_kd, dampening=dampening,
+ weight_decay=weight_decay)
+ self.kd_grad_buffer = []
+ self.kd_grad_params = []
+ self.kd_momentum_buffer = []
+ super(DistillationOrientedTrainer, self).__init__(params, defaults)
+
+ @torch.no_grad()
+ def step_kd(self, closure=None):
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ assert len(self.param_groups) == 1, "Only implement for one-group params."
+ for group in self.param_groups:
+ params_with_grad = []
+ d_p_list = []
+ momentum_kd_buffer_list = []
+ weight_decay = group['weight_decay']
+ dampening = group['dampening']
+ lr = group['lr']
+ for p in group['params']:
+ if p.grad is not None:
+ params_with_grad.append(p)
+ d_p_list.append(p.grad)
+ state = self.state[p]
+ if 'momentum_kd_buffer' not in state:
+ momentum_kd_buffer_list.append(None)
+ else:
+ momentum_kd_buffer_list.append(state['momentum_kd_buffer'])
+
+ self.kd_momentum_buffer = momentum_kd_buffer_list
+ self.kd_grad_buffer = d_p_list
+ self.kd_grad_params = params_with_grad
+ return loss
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+ assert len(self.param_groups) == 1, "Only implement for one-group params."
+ for group in self.param_groups:
+ params_with_grad = []
+ d_p_list = []
+ momentum_buffer_list = []
+ weight_decay = group['weight_decay']
+ momentum = group['momentum']
+ momentum_kd = group['momentum_kd']
+ dampening = group['dampening']
+ lr = group['lr']
+
+ for p in group['params']:
+ if p.grad is not None:
+ params_with_grad.append(p)
+ d_p_list.append(p.grad)
+
+ state = self.state[p]
+ if 'momentum_buffer' not in state:
+ momentum_buffer_list.append(None)
+ else:
+ momentum_buffer_list.append(state['momentum_buffer'])
+ dot(params_with_grad,
+ d_p_list,
+ momentum_buffer_list,
+ self.kd_grad_buffer,
+ self.kd_momentum_buffer,
+ self.kd_grad_params,
+ weight_decay=weight_decay,
+ momentum=momentum,
+ momentum_kd=momentum_kd,
+ lr=lr,
+ dampening=dampening)
+ # update momentum_buffers in state
+ for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
+ state = self.state[p]
+ state['momentum_buffer'] = momentum_buffer
+ for p, momentum_kd_buffer in zip(self.kd_grad_params, self.kd_momentum_buffer):
+ state = self.state[p]
+ state['momentum_kd_buffer'] = momentum_kd_buffer
+ self.kd_grad_buffer = []
+ self.kd_grad_params = []
+ self.kd_momentum_buffer = []
+ return loss
diff --git a/mdistiller/engine/trainer.py b/mdistiller/engine/trainer.py
index 49a5fe4..94d1643 100755
--- a/mdistiller/engine/trainer.py
+++ b/mdistiller/engine/trainer.py
@@ -16,6 +16,7 @@
load_checkpoint,
log_msg,
)
+from .dot import DistillationOrientedTrainer
class BaseTrainer(object):
@@ -223,3 +224,151 @@ def train_iter(self, data, epoch, train_meters):
train_meters["top5"].avg,
)
return msg
+
+
+class DOT(BaseTrainer):
+ def init_optimizer(self, cfg):
+ if cfg.SOLVER.TYPE == "SGD":
+ m_task = cfg.SOLVER.MOMENTUM - cfg.SOLVER.DOT.DELTA
+ m_kd = cfg.SOLVER.MOMENTUM + cfg.SOLVER.DOT.DELTA
+ optimizer = DistillationOrientedTrainer(
+ self.distiller.module.get_learnable_parameters(),
+ lr=cfg.SOLVER.LR,
+ momentum=m_task,
+ momentum_kd=m_kd,
+ weight_decay=cfg.SOLVER.WEIGHT_DECAY,
+ )
+ else:
+ raise NotImplementedError(cfg.SOLVER.TYPE)
+ return optimizer
+
+ def train(self, resume=False):
+ epoch = 1
+ if resume:
+ state = load_checkpoint(os.path.join(self.log_path, "latest"))
+ epoch = state["epoch"] + 1
+ self.distiller.load_state_dict(state["model"])
+ self.optimizer.load_state_dict(state["optimizer"])
+ self.best_acc = state["best_acc"]
+ while epoch < self.cfg.SOLVER.EPOCHS + 1:
+ self.train_epoch(epoch)
+ epoch += 1
+ print(log_msg("Best accuracy:{}".format(self.best_acc), "EVAL"))
+ with open(os.path.join(self.log_path, "worklog.txt"), "a") as writer:
+ writer.write("best_acc\t" + "{:.2f}".format(float(self.best_acc)))
+
+ def train_iter(self, data, epoch, train_meters):
+ train_start_time = time.time()
+ image, target, index = data
+ train_meters["data_time"].update(time.time() - train_start_time)
+ image = image.float()
+ image = image.cuda(non_blocking=True)
+ target = target.cuda(non_blocking=True)
+ index = index.cuda(non_blocking=True)
+
+ # forward
+ preds, losses_dict = self.distiller(image=image, target=target, epoch=epoch)
+
+ # dot backward
+ loss_ce, loss_kd = losses_dict['loss_ce'].mean(), losses_dict['loss_kd'].mean()
+ self.optimizer.zero_grad(set_to_none=True)
+ loss_kd.backward(retain_graph=True)
+ self.optimizer.step_kd()
+ self.optimizer.zero_grad(set_to_none=True)
+ loss_ce.backward()
+ self.optimizer.step()
+
+ train_meters["training_time"].update(time.time() - train_start_time)
+ # collect info
+ batch_size = image.size(0)
+ acc1, acc5 = accuracy(preds, target, topk=(1, 5))
+ train_meters["losses"].update((loss_ce + loss_kd).cpu().detach().numpy().mean(), batch_size)
+ train_meters["top1"].update(acc1[0], batch_size)
+ train_meters["top5"].update(acc5[0], batch_size)
+ # print info
+ msg = "Epoch:{}| Time(data):{:.3f}| Time(train):{:.3f}| Loss:{:.4f}| Top-1:{:.3f}| Top-5:{:.3f}".format(
+ epoch,
+ train_meters["data_time"].avg,
+ train_meters["training_time"].avg,
+ train_meters["losses"].avg,
+ train_meters["top1"].avg,
+ train_meters["top5"].avg,
+ )
+ return msg
+
+
+class CRDDOT(BaseTrainer):
+
+ def init_optimizer(self, cfg):
+ if cfg.SOLVER.TYPE == "SGD":
+ m_task = cfg.SOLVER.MOMENTUM - cfg.SOLVER.DOT.DELTA
+ m_kd = cfg.SOLVER.MOMENTUM + cfg.SOLVER.DOT.DELTA
+ optimizer = DistillationOrientedTrainer(
+ self.distiller.module.get_learnable_parameters(),
+ lr=cfg.SOLVER.LR,
+ momentum=m_task,
+ momentum_kd=m_kd,
+ weight_decay=cfg.SOLVER.WEIGHT_DECAY,
+ )
+ else:
+ raise NotImplementedError(cfg.SOLVER.TYPE)
+ return optimizer
+
+ def train(self, resume=False):
+ epoch = 1
+ if resume:
+ state = load_checkpoint(os.path.join(self.log_path, "latest"))
+ epoch = state["epoch"] + 1
+ self.distiller.load_state_dict(state["model"])
+ self.optimizer.load_state_dict(state["optimizer"])
+ self.best_acc = state["best_acc"]
+ while epoch < self.cfg.SOLVER.EPOCHS + 1:
+ self.train_epoch(epoch)
+ epoch += 1
+ print(log_msg("Best accuracy:{}".format(self.best_acc), "EVAL"))
+ with open(os.path.join(self.log_path, "worklog.txt"), "a") as writer:
+ writer.write("best_acc\t" + "{:.2f}".format(float(self.best_acc)))
+
+ def train_iter(self, data, epoch, train_meters):
+ self.optimizer.zero_grad()
+ train_start_time = time.time()
+ image, target, index, contrastive_index = data
+ train_meters["data_time"].update(time.time() - train_start_time)
+ image = image.float()
+ image = image.cuda(non_blocking=True)
+ target = target.cuda(non_blocking=True)
+ index = index.cuda(non_blocking=True)
+ contrastive_index = contrastive_index.cuda(non_blocking=True)
+
+ # forward
+ preds, losses_dict = self.distiller(
+ image=image, target=target, index=index, contrastive_index=contrastive_index
+ )
+
+ # dot backward
+ loss_ce, loss_kd = losses_dict['loss_ce'].mean(), losses_dict['loss_kd'].mean()
+ self.optimizer.zero_grad(set_to_none=True)
+ loss_kd.backward(retain_graph=True)
+ self.optimizer.step_kd()
+ self.optimizer.zero_grad(set_to_none=True)
+ loss_ce.backward()
+ # self.optimizer.step((1 - epoch / 240.))
+ self.optimizer.step()
+
+ train_meters["training_time"].update(time.time() - train_start_time)
+ # collect info
+ batch_size = image.size(0)
+ acc1, acc5 = accuracy(preds, target, topk=(1, 5))
+ train_meters["losses"].update((loss_ce + loss_kd).cpu().detach().numpy().mean(), batch_size)
+ train_meters["top1"].update(acc1[0], batch_size)
+ train_meters["top5"].update(acc5[0], batch_size)
+ # print info
+ msg = "Epoch:{}| Time(data):{:.3f}| Time(train):{:.3f}| Loss:{:.4f}| Top-1:{:.3f}| Top-5:{:.3f}".format(
+ epoch,
+ train_meters["data_time"].avg,
+ train_meters["training_time"].avg,
+ train_meters["losses"].avg,
+ train_meters["top1"].avg,
+ train_meters["top5"].avg,
+ )
+ return msg
diff --git a/mdistiller/models/__init__.py b/mdistiller/models/__init__.py
index e8411d7..624b968 100755
--- a/mdistiller/models/__init__.py
+++ b/mdistiller/models/__init__.py
@@ -1,2 +1,2 @@
-from .cifar import cifar_model_dict
+from .cifar import cifar_model_dict, tiny_imagenet_model_dict
from .imagenet import imagenet_model_dict
diff --git a/mdistiller/models/cifar/__init__.py b/mdistiller/models/cifar/__init__.py
index ccf023c..c39916f 100755
--- a/mdistiller/models/cifar/__init__.py
+++ b/mdistiller/models/cifar/__init__.py
@@ -16,12 +16,19 @@
from .mobilenetv2 import mobile_half
from .ShuffleNetv1 import ShuffleV1
from .ShuffleNetv2 import ShuffleV2
+from .mv2_tinyimagenet import mobilenetv2_tinyimagenet
cifar100_model_prefix = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"../../../download_ckpts/cifar_teachers/"
)
+
+tiny_imagenet_model_prefix = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)),
+ "../../../download_ckpts/tiny_imagenet_teachers/"
+)
+
cifar_model_dict = {
# teachers
"resnet56": (
@@ -64,3 +71,10 @@
"ShuffleV1": (ShuffleV1, None),
"ShuffleV2": (ShuffleV2, None),
}
+
+
+tiny_imagenet_model_dict = {
+ "ResNet18": (ResNet18, tiny_imagenet_model_prefix + "ResNet18_vanilla/ti_res18"),
+ "MobileNetV2": (mobilenetv2_tinyimagenet, None),
+ "ShuffleV2": (ShuffleV2, None),
+}
diff --git a/mdistiller/models/cifar/mv2_tinyimagenet.py b/mdistiller/models/cifar/mv2_tinyimagenet.py
new file mode 100644
index 0000000..4a88944
--- /dev/null
+++ b/mdistiller/models/cifar/mv2_tinyimagenet.py
@@ -0,0 +1,101 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class LinearBottleNeck(nn.Module):
+
+ def __init__(self, in_channels, out_channels, stride, t=6, class_num=100):
+ super().__init__()
+
+ self.residual = nn.Sequential(
+ nn.Conv2d(in_channels, in_channels * t, 1),
+ nn.BatchNorm2d(in_channels * t),
+ nn.ReLU6(inplace=True),
+
+ nn.Conv2d(in_channels * t, in_channels * t, 3, stride=stride, padding=1, groups=in_channels * t),
+ nn.BatchNorm2d(in_channels * t),
+ nn.ReLU6(inplace=True),
+
+ nn.Conv2d(in_channels * t, out_channels, 1),
+ nn.BatchNorm2d(out_channels)
+ )
+
+ self.stride = stride
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ def forward(self, x):
+
+ residual = self.residual(x)
+
+ if self.stride == 1 and self.in_channels == self.out_channels:
+ residual += x
+
+ return residual
+
+class MobileNetV2(nn.Module):
+
+ def __init__(self, num_classes=100):
+ super().__init__()
+
+ self.pre = nn.Sequential(
+ nn.Conv2d(3, 32, 1, padding=1),
+ nn.BatchNorm2d(32),
+ nn.ReLU6(inplace=True)
+ )
+
+ self.stage1 = LinearBottleNeck(32, 16, 1, 1)
+ self.stage2 = self._make_stage(2, 16, 24, 2, 6)
+ self.stage3 = self._make_stage(3, 24, 32, 2, 6)
+ self.stage4 = self._make_stage(4, 32, 64, 2, 6)
+ self.stage5 = self._make_stage(3, 64, 96, 1, 6)
+ self.stage6 = self._make_stage(3, 96, 160, 1, 6)
+ self.stage7 = LinearBottleNeck(160, 320, 1, 6)
+
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(320, 1280, 1),
+ nn.BatchNorm2d(1280),
+ nn.ReLU6(inplace=True)
+ )
+
+ self.conv2 = nn.Conv2d(1280, num_classes, 1)
+
+ def forward(self, x):
+ x = self.pre(x)
+ f0 = x
+ x = self.stage1(x)
+ x = self.stage2(x)
+ f1 = x
+ x = self.stage3(x)
+ f2 = x
+ x = self.stage4(x)
+ f3 = x
+ x = self.stage5(x)
+ x = self.stage6(x)
+ x = self.stage7(x)
+ x = self.conv1(x)
+ f4 = x
+ x = F.adaptive_avg_pool2d(x, 1)
+ avg = x
+ x = self.conv2(x)
+ x = x.view(x.size(0), -1)
+ feats = {}
+ feats["feats"] = [f0, f1, f2, f3, f4]
+ feats["pooled_feat"] = avg
+
+ return x, feats
+
+ def _make_stage(self, repeat, in_channels, out_channels, stride, t):
+
+ layers = []
+ layers.append(LinearBottleNeck(in_channels, out_channels, stride, t))
+
+ while repeat - 1:
+ layers.append(LinearBottleNeck(out_channels, out_channels, 1, t))
+ repeat -= 1
+
+ return nn.Sequential(*layers)
+
+def mobilenetv2_tinyimagenet(**kwargs):
+ return MobileNetV2(**kwargs)
\ No newline at end of file
diff --git a/tools/eval.py b/tools/eval.py
index 10c2f16..e92580b 100755
--- a/tools/eval.py
+++ b/tools/eval.py
@@ -5,7 +5,7 @@
cudnn.benchmark = True
from mdistiller.distillers import Vanilla
-from mdistiller.models import cifar_model_dict, imagenet_model_dict
+from mdistiller.models import cifar_model_dict, imagenet_model_dict, tiny_imagenet_model_dict
from mdistiller.dataset import get_dataset
from mdistiller.dataset.imagenet import get_imagenet_val_loader
from mdistiller.engine.utils import load_checkpoint, validate
@@ -21,7 +21,7 @@
"--dataset",
type=str,
default="cifar100",
- choices=["cifar100", "imagenet"],
+ choices=["cifar100", "imagenet", "tiny_imagenet"],
)
parser.add_argument("-bs", "--batch-size", type=int, default=64)
args = parser.parse_args()
@@ -35,9 +35,10 @@
else:
model = imagenet_model_dict[args.model](pretrained=False)
model.load_state_dict(load_checkpoint(args.ckpt)["model"])
- elif args.dataset == "cifar100":
+ elif args.dataset in ("cifar100", "tiny_imagenet"):
train_loader, val_loader, num_data, num_classes = get_dataset(cfg)
- model, pretrain_model_path = cifar_model_dict[args.model]
+ model_dict = tiny_imagenet_model_dict if args.dataset == "tiny_imagenet" else cifar_model_dict
+ model, pretrain_model_path = model_dict[args.model]
model = model(num_classes=num_classes)
ckpt = pretrain_model_path if args.ckpt == "pretrain" else args.ckpt
model.load_state_dict(load_checkpoint(ckpt)["model"])
diff --git a/tools/train.py b/tools/train.py
index ea41035..de8dbf6 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -6,7 +6,7 @@
cudnn.benchmark = True
-from mdistiller.models import cifar_model_dict, imagenet_model_dict
+from mdistiller.models import cifar_model_dict, imagenet_model_dict, tiny_imagenet_model_dict
from mdistiller.distillers import distiller_dict
from mdistiller.dataset import get_dataset
from mdistiller.engine.utils import load_checkpoint, log_msg
@@ -43,6 +43,8 @@ def main(cfg, resume, opts):
if cfg.DISTILLER.TYPE == "NONE":
if cfg.DATASET.TYPE == "imagenet":
model_student = imagenet_model_dict[cfg.DISTILLER.STUDENT](pretrained=False)
+ elif cfg.DATASET.TYPE == "tiny_imagenet":
+ model_student = tiny_imagenet_model_dict[cfg.DISTILLER.STUDENT][0](num_classes=num_classes)
else:
model_student = cifar_model_dict[cfg.DISTILLER.STUDENT][0](
num_classes=num_classes
@@ -55,13 +57,14 @@ def main(cfg, resume, opts):
model_teacher = imagenet_model_dict[cfg.DISTILLER.TEACHER](pretrained=True)
model_student = imagenet_model_dict[cfg.DISTILLER.STUDENT](pretrained=False)
else:
- net, pretrain_model_path = cifar_model_dict[cfg.DISTILLER.TEACHER]
+ model_dict = tiny_imagenet_model_dict if cfg.DATASET.TYPE == "tiny_imagenet" else cifar_model_dict
+ net, pretrain_model_path = model_dict[cfg.DISTILLER.TEACHER]
assert (
pretrain_model_path is not None
), "no pretrain model for teacher {}".format(cfg.DISTILLER.TEACHER)
model_teacher = net(num_classes=num_classes)
model_teacher.load_state_dict(load_checkpoint(pretrain_model_path)["model"])
- model_student = cifar_model_dict[cfg.DISTILLER.STUDENT][0](
+ model_student = model_dict[cfg.DISTILLER.STUDENT][0](
num_classes=num_classes
)
if cfg.DISTILLER.TYPE == "CRD":