diff --git a/.github/dot.png b/.github/dot.png new file mode 100644 index 0000000..4ea6ab5 Binary files /dev/null and b/.github/dot.png differ diff --git a/README.md b/README.md index 88279ae..513462a 100755 --- a/README.md +++ b/README.md @@ -6,6 +6,37 @@ This repo is (2) the official implementation of the CVPR-2022 paper: [Decoupled Knowledge Distillation](https://arxiv.org/abs/2203.08679). +(3) the official implementation of the ICCV-2023 paper: [DOT: A Distillation-Oriented Trainer](https://openaccess.thecvf.com/content/ICCV2023/papers/Zhao_DOT_A_Distillation-Oriented_Trainer_ICCV_2023_paper.pdf). + + +# DOT: A Distillation-Oriented Trainer + +### Framework + +![DOT framework](.github/dot.png)
+ +### Main Benchmark Results + +On CIFAR-100: + +| Teacher <br> Student | ResNet32x4 <br> ResNet8x4| VGG13 <br> VGG8| ResNet32x4 <br> ShuffleNet-V2| +|:---------------:|:-----------------:|:-----------------:|:-----------------:| +| KD | 73.33 | 72.98 | 74.45 | +| **KD+DOT** | **75.12** | **73.77** | **75.55** | + +On Tiny-ImageNet: + +| Teacher <br> Student |ResNet18 <br> MobileNet-V2|ResNet18 <br> ShuffleNet-V2| +|:---------------:|:-----------------:|:-----------------:| +| KD | 58.35 | 62.26 | +| **KD+DOT** | **64.01** | **65.75** | + +On ImageNet: + +| Teacher <br> Student |ResNet34 <br> ResNet18|ResNet50 <br> MobileNet-V1| +|:---------------:|:-----------------:|:-----------------:| +| KD | 71.03 | 70.50 | +| **KD+DOT** | **71.72** | **73.09** | # Decoupled Knowledge Distillation @@ -170,6 +201,12 @@ If this repo is helpful for your research, please consider citing the paper: journal={arXiv preprint arXiv:2203.08679}, year={2022} } +@article{zhao2023dot, + title={DOT: A Distillation-Oriented Trainer}, + author={Zhao, Borui and Cui, Quan and Song, Renjie and Liang, Jiajun}, + journal={arXiv preprint arXiv:2307.08436}, + year={2023} +} ``` # License diff --git a/configs/cifar100/dot/res32x4_res8x4.yaml b/configs/cifar100/dot/res32x4_res8x4.yaml new file mode 100644 index 0000000..81b406d --- /dev/null +++ b/configs/cifar100/dot/res32x4_res8x4.yaml @@ -0,0 +1,20 @@ +EXPERIMENT: + NAME: "" + TAG: "kd,dot,res32x4,res8x4" + PROJECT: "dot_cifar" +DISTILLER: + TYPE: "KD" + TEACHER: "resnet32x4" + STUDENT: "resnet8x4" +SOLVER: + BATCH_SIZE: 64 + EPOCHS: 240 + LR: 0.05 + LR_DECAY_STAGES: [150, 180, 210] + LR_DECAY_RATE: 0.1 + WEIGHT_DECAY: 0.0005 + MOMENTUM: 0.9 + TYPE: "SGD" + TRAINER: "dot" + DOT: + DELTA: 0.075 diff --git a/configs/cifar100/dot/res32x4_shuv2.yaml b/configs/cifar100/dot/res32x4_shuv2.yaml new file mode 100755 index 0000000..9072871 --- /dev/null +++ b/configs/cifar100/dot/res32x4_shuv2.yaml @@ -0,0 +1,20 @@ +EXPERIMENT: + NAME: "" + TAG: "kd,dot,res32x4,shuv2" + PROJECT: "dot_cifar" +DISTILLER: + TYPE: "KD" + TEACHER: "resnet32x4" + STUDENT: "ShuffleV2" +SOLVER: + BATCH_SIZE: 64 + EPOCHS: 240 + LR: 0.01 + LR_DECAY_STAGES: [150, 180, 210] + LR_DECAY_RATE: 0.1 + WEIGHT_DECAY: 0.0005 + MOMENTUM: 0.9 + TYPE: "SGD" + TRAINER: "dot" + DOT: + DELTA: 0.075 diff --git a/configs/cifar100/dot/vgg13_vgg8.yaml b/configs/cifar100/dot/vgg13_vgg8.yaml new file mode 100755 index 0000000..9a684ef --- /dev/null +++ b/configs/cifar100/dot/vgg13_vgg8.yaml @@ -0,0 +1,20 @@ +EXPERIMENT: + NAME: "" + TAG: "kd,dot,vgg13,vgg8" + PROJECT: "dot_cifar" +DISTILLER: + TYPE: "KD" + TEACHER: 
"vgg13" + STUDENT: "vgg8" +SOLVER: + BATCH_SIZE: 64 + EPOCHS: 240 + LR: 0.05 + LR_DECAY_STAGES: [150, 180, 210] + LR_DECAY_RATE: 0.1 + WEIGHT_DECAY: 0.0005 + MOMENTUM: 0.9 + TYPE: "SGD" + TRAINER: "dot" + DOT: + DELTA: 0.075 diff --git a/configs/imagenet/r34_r18/dot.yaml b/configs/imagenet/r34_r18/dot.yaml new file mode 100755 index 0000000..8d5e7b3 --- /dev/null +++ b/configs/imagenet/r34_r18/dot.yaml @@ -0,0 +1,33 @@ +EXPERIMENT: + NAME: "" + TAG: "kd,dot,res34,res18" + PROJECT: "dot_imagenet" +DATASET: + TYPE: "imagenet" + NUM_WORKERS: 32 + TEST: + BATCH_SIZE: 128 +DISTILLER: + TYPE: "KD" + TEACHER: "ResNet34" + STUDENT: "ResNet18" +SOLVER: + BATCH_SIZE: 512 + EPOCHS: 100 + LR: 0.2 + LR_DECAY_STAGES: [30, 60, 90] + LR_DECAY_RATE: 0.1 + WEIGHT_DECAY: 0.0001 + MOMENTUM: 0.9 + TYPE: "SGD" + TRAINER: "dot" + DOT: + DELTA: 0.09 +KD: + TEMPERATURE: 1 + LOSS: + CE_WEIGHT: 0.5 + KD_WEIGHT: 0.5 +LOG: + TENSORBOARD_FREQ: 50 + SAVE_CHECKPOINT_FREQ: 10 diff --git a/configs/imagenet/r50_mv1/dot.yaml b/configs/imagenet/r50_mv1/dot.yaml new file mode 100755 index 0000000..a8bd6c7 --- /dev/null +++ b/configs/imagenet/r50_mv1/dot.yaml @@ -0,0 +1,33 @@ +EXPERIMENT: + NAME: "" + TAG: "kd,dot,res50,mobilenetv1" + PROJECT: "dot_imagenet" +DATASET: + TYPE: "imagenet" + NUM_WORKERS: 32 + TEST: + BATCH_SIZE: 128 +DISTILLER: + TYPE: "KD" + TEACHER: "ResNet50" + STUDENT: "MobileNetV1" +SOLVER: + BATCH_SIZE: 512 + EPOCHS: 100 + LR: 0.2 + LR_DECAY_STAGES: [30, 60, 90] + LR_DECAY_RATE: 0.1 + WEIGHT_DECAY: 0.0001 + MOMENTUM: 0.9 + TYPE: "SGD" + TRAINER: "dot" + DOT: + DELTA: 0.09 +KD: + TEMPERATURE: 1 + LOSS: + CE_WEIGHT: 0.5 + KD_WEIGHT: 0.5 +LOG: + TENSORBOARD_FREQ: 50 + SAVE_CHECKPOINT_FREQ: 10 diff --git a/configs/tiny_imagenet/dot/r18_mv2.yaml b/configs/tiny_imagenet/dot/r18_mv2.yaml new file mode 100644 index 0000000..02bebae --- /dev/null +++ b/configs/tiny_imagenet/dot/r18_mv2.yaml @@ -0,0 +1,23 @@ +EXPERIMENT: + NAME: "" + TAG: "kd,dot,r18,mv2" + PROJECT: "dot_tinyimagenet" 
+DATASET: + TYPE: "tiny_imagenet" + NUM_WORKERS: 16 +DISTILLER: + TYPE: "KD" + TEACHER: "ResNet18" + STUDENT: "MobileNetV2" +SOLVER: + BATCH_SIZE: 256 + EPOCHS: 200 + LR: 0.2 + LR_DECAY_STAGES: [60, 120, 160] + LR_DECAY_RATE: 0.1 + WEIGHT_DECAY: 0.0005 + MOMENTUM: 0.9 + TYPE: "SGD" + TRAINER: "dot" + DOT: + DELTA: 0.075 diff --git a/configs/tiny_imagenet/dot/r18_shuv2.yaml b/configs/tiny_imagenet/dot/r18_shuv2.yaml new file mode 100644 index 0000000..972da3d --- /dev/null +++ b/configs/tiny_imagenet/dot/r18_shuv2.yaml @@ -0,0 +1,23 @@ +EXPERIMENT: + NAME: "" + TAG: "kd,dot,r18,shuv2" + PROJECT: "dot_tinyimagenet" +DATASET: + TYPE: "tiny_imagenet" + NUM_WORKERS: 16 +DISTILLER: + TYPE: "KD" + TEACHER: "ResNet18" + STUDENT: "ShuffleV2" +SOLVER: + BATCH_SIZE: 256 + EPOCHS: 200 + LR: 0.2 + LR_DECAY_STAGES: [60, 120, 160] + LR_DECAY_RATE: 0.1 + WEIGHT_DECAY: 0.0005 + MOMENTUM: 0.9 + TYPE: "SGD" + TRAINER: "dot" + DOT: + DELTA: 0.075 diff --git a/mdistiller/dataset/__init__.py b/mdistiller/dataset/__init__.py index f612437..afaea69 100755 --- a/mdistiller/dataset/__init__.py +++ b/mdistiller/dataset/__init__.py @@ -1,5 +1,6 @@ from .cifar100 import get_cifar100_dataloaders, get_cifar100_dataloaders_sample from .imagenet import get_imagenet_dataloaders, get_imagenet_dataloaders_sample +from .tiny_imagenet import get_tinyimagenet_dataloader, get_tinyimagenet_dataloader_sample def get_dataset(cfg): @@ -34,6 +35,21 @@ def get_dataset(cfg): num_workers=cfg.DATASET.NUM_WORKERS, ) num_classes = 1000 + elif cfg.DATASET.TYPE == "tiny_imagenet": + if cfg.DISTILLER.TYPE in ("CRD", "CRDKD"): + train_loader, val_loader, num_data = get_tinyimagenet_dataloader_sample( + batch_size=cfg.SOLVER.BATCH_SIZE, + val_batch_size=cfg.DATASET.TEST.BATCH_SIZE, + num_workers=cfg.DATASET.NUM_WORKERS, + k=cfg.CRD.NCE.K, + ) + else: + train_loader, val_loader, num_data = get_tinyimagenet_dataloader( + batch_size=cfg.SOLVER.BATCH_SIZE, + val_batch_size=cfg.DATASET.TEST.BATCH_SIZE, + 
num_workers=cfg.DATASET.NUM_WORKERS, + ) + num_classes = 200 else: raise NotImplementedError(cfg.DATASET.TYPE) diff --git a/mdistiller/dataset/tiny_imagenet.py b/mdistiller/dataset/tiny_imagenet.py new file mode 100644 index 0000000..e1b6e7b --- /dev/null +++ b/mdistiller/dataset/tiny_imagenet.py @@ -0,0 +1,122 @@ +import os +from torch.utils.data import DataLoader +from torchvision import datasets +from torchvision import transforms +import numpy as np + + +data_folder = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "../../data/tiny-imagenet-200" +) + + +class ImageFolderInstance(datasets.ImageFolder): + def __getitem__(self, index): + path, target = self.imgs[index] + img = self.loader(path) + if self.transform is not None: + img = self.transform(img) + return img, target, index + + +class ImageFolderInstanceSample(ImageFolderInstance): + """: Folder datasets which returns (img, label, index, contrast_index): + """ + def __init__(self, folder, transform=None, target_transform=None, + is_sample=False, k=4096): + super().__init__(folder, transform=transform) + + self.k = k + self.is_sample = is_sample + if self.is_sample: + num_classes = 200 + num_samples = len(self.samples) + label = np.zeros(num_samples, dtype=np.int32) + for i in range(num_samples): + img, target = self.samples[i] + label[i] = target + + self.cls_positive = [[] for i in range(num_classes)] + for i in range(num_samples): + self.cls_positive[label[i]].append(i) + + self.cls_negative = [[] for i in range(num_classes)] + for i in range(num_classes): + for j in range(num_classes): + if j == i: + continue + self.cls_negative[i].extend(self.cls_positive[j]) + + self.cls_positive = [np.asarray(self.cls_positive[i], dtype=np.int32) for i in range(num_classes)] + self.cls_negative = [np.asarray(self.cls_negative[i], dtype=np.int32) for i in range(num_classes)] + print('dataset initialized!') + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: (image, 
target) where target is class_index of the target class. + """ + img, target, index = super().__getitem__(index) + + if self.is_sample: + # sample contrastive examples + pos_idx = index + neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=True) + sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx)) + return img, target, index, sample_idx + else: + return img, target, index + + +def get_tinyimagenet_dataloader(batch_size, val_batch_size, num_workers): + """Data Loader for tiny-imagenet""" + train_transform = transforms.Compose([ + transforms.RandomRotation(20), + transforms.RandomHorizontalFlip(0.5), + transforms.ToTensor(), + transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]), + ]) + test_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]), + ]) + train_folder = os.path.join(data_folder, "train") + test_folder = os.path.join(data_folder, "val") + train_set = ImageFolderInstance(train_folder, transform=train_transform) + num_data = len(train_set) + test_set = datasets.ImageFolder(test_folder, transform=test_transform) + train_loader = DataLoader( + train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers + ) + test_loader = DataLoader( + test_set, batch_size=val_batch_size, shuffle=False, num_workers=1 + ) + return train_loader, test_loader, num_data + + +def get_tinyimagenet_dataloader_sample(batch_size, val_batch_size, num_workers, k): + """Data Loader for tiny-imagenet""" + train_transform = transforms.Compose([ + transforms.RandomRotation(20), + transforms.RandomHorizontalFlip(0.5), + transforms.ToTensor(), + transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]), + ]) + test_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]), + ]) + train_folder = os.path.join(data_folder, "train") + test_folder = os.path.join(data_folder, 
"val") + train_set = ImageFolderInstanceSample(train_folder, transform=train_transform, is_sample=True, k=k) + num_data = len(train_set) + test_set = datasets.ImageFolder(test_folder, transform=test_transform) + train_loader = DataLoader( + train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers + ) + test_loader = DataLoader( + test_set, batch_size=val_batch_size, shuffle=False, num_workers=1 + ) + return train_loader, test_loader, num_data diff --git a/mdistiller/engine/__init__.py b/mdistiller/engine/__init__.py index d931376..4883802 100755 --- a/mdistiller/engine/__init__.py +++ b/mdistiller/engine/__init__.py @@ -1,6 +1,7 @@ -from .trainer import BaseTrainer, CRDTrainer - +from .trainer import BaseTrainer, CRDTrainer, DOT, CRDDOT trainer_dict = { "base": BaseTrainer, "crd": CRDTrainer, + "dot": DOT, + "crd_dot": CRDDOT, } diff --git a/mdistiller/engine/cfg.py b/mdistiller/engine/cfg.py index e8179fa..41625f0 100755 --- a/mdistiller/engine/cfg.py +++ b/mdistiller/engine/cfg.py @@ -165,3 +165,8 @@ def show_cfg(cfg): CFG.DKD.BETA = 8.0 CFG.DKD.T = 4.0 CFG.DKD.WARMUP = 20 + + +# DOT CFG +CFG.SOLVER.DOT = CN() +CFG.SOLVER.DOT.DELTA = 0.075 diff --git a/mdistiller/engine/dot.py b/mdistiller/engine/dot.py new file mode 100644 index 0000000..1619781 --- /dev/null +++ b/mdistiller/engine/dot.py @@ -0,0 +1,174 @@ +import math +import torch +from torch import Tensor +import torch.optim._functional as F +from torch.optim.optimizer import Optimizer, required +from typing import List, Optional + + +def check_in(t, l): + for i in l: + if t is i: + return True + return False + +def dot(params: List[Tensor], + d_p_list: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], + kd_grad_buffer: List[Optional[Tensor]], + kd_momentum_buffer: List[Optional[Tensor]], + kd_params: List[Tensor], + *, + weight_decay: float, + momentum: float, + momentum_kd: float, + lr: float, + dampening: float): + for i, param in enumerate(params): + d_p = d_p_list[i] + if 
weight_decay != 0: + d_p = d_p.add(param, alpha=weight_decay) + if momentum != 0: + buf = momentum_buffer_list[i] + if buf is None: + buf = torch.clone(d_p).detach() + momentum_buffer_list[i] = buf + elif check_in(param, kd_params): + buf.mul_(momentum).add_(d_p, alpha=1 - dampening) + else: + buf.mul_((momentum_kd + momentum) / 2.).add_(d_p, alpha=1 - dampening) + d_p = buf + # update params with task-loss grad + param.add_(d_p, alpha=-lr) + + for i, (d_p, buf, p) in enumerate(zip(kd_grad_buffer, kd_momentum_buffer, kd_params)): + # update params with kd-loss grad + if buf is None: + buf = torch.clone(d_p).detach() + kd_momentum_buffer[i] = buf + elif check_in(p, params): + buf.mul_(momentum_kd).add_(d_p, alpha=1 - dampening) + else: + if weight_decay != 0: + d_p = d_p.add(p, alpha=weight_decay) + buf.mul_((momentum_kd + momentum) / 2.).add_(d_p, alpha=1 - dampening) + p.add_(buf, alpha=-lr) + + +class DistillationOrientedTrainer(Optimizer): + r""" + Distillation-Oriented Trainer + Usage: + ... + optimizer = DistillationOrientedTrainer() + optimizer.zero_grad(set_to_none=True) + loss_kd.backward(retain_graph=True) + optimizer.step_kd() # get kd-grad and update kd-momentum + optimizer.zero_grad(set_to_none=True) + loss_task.backward() + optimizer.step() # get task-grad and update tast-momentum, then update params. + ... 
+ """ + + def __init__( + self, + params, + lr=required, + momentum=0, + momentum_kd=0, + dampening=0, + weight_decay=0): + if lr is not required and lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if momentum < 0.0: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if momentum_kd < 0.0: + raise ValueError("Invalid momentum kd value: {}".format(momentum_kd)) + if weight_decay < 0.0: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + defaults = dict(lr=lr, momentum=momentum, momentum_kd=momentum_kd, dampening=dampening, + weight_decay=weight_decay) + self.kd_grad_buffer = [] + self.kd_grad_params = [] + self.kd_momentum_buffer = [] + super(DistillationOrientedTrainer, self).__init__(params, defaults) + + @torch.no_grad() + def step_kd(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + assert len(self.param_groups) == 1, "Only implement for one-group params." + for group in self.param_groups: + params_with_grad = [] + d_p_list = [] + momentum_kd_buffer_list = [] + weight_decay = group['weight_decay'] + dampening = group['dampening'] + lr = group['lr'] + for p in group['params']: + if p.grad is not None: + params_with_grad.append(p) + d_p_list.append(p.grad) + state = self.state[p] + if 'momentum_kd_buffer' not in state: + momentum_kd_buffer_list.append(None) + else: + momentum_kd_buffer_list.append(state['momentum_kd_buffer']) + + self.kd_momentum_buffer = momentum_kd_buffer_list + self.kd_grad_buffer = d_p_list + self.kd_grad_params = params_with_grad + return loss + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + assert len(self.param_groups) == 1, "Only implement for one-group params." 
+ for group in self.param_groups: + params_with_grad = [] + d_p_list = [] + momentum_buffer_list = [] + weight_decay = group['weight_decay'] + momentum = group['momentum'] + momentum_kd = group['momentum_kd'] + dampening = group['dampening'] + lr = group['lr'] + + for p in group['params']: + if p.grad is not None: + params_with_grad.append(p) + d_p_list.append(p.grad) + + state = self.state[p] + if 'momentum_buffer' not in state: + momentum_buffer_list.append(None) + else: + momentum_buffer_list.append(state['momentum_buffer']) + dot(params_with_grad, + d_p_list, + momentum_buffer_list, + self.kd_grad_buffer, + self.kd_momentum_buffer, + self.kd_grad_params, + weight_decay=weight_decay, + momentum=momentum, + momentum_kd=momentum_kd, + lr=lr, + dampening=dampening) + # update momentum_buffers in state + for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list): + state = self.state[p] + state['momentum_buffer'] = momentum_buffer + for p, momentum_kd_buffer in zip(self.kd_grad_params, self.kd_momentum_buffer): + state = self.state[p] + state['momentum_kd_buffer'] = momentum_kd_buffer + self.kd_grad_buffer = [] + self.kd_grad_params = [] + self.kd_momentum_buffer = [] + return loss diff --git a/mdistiller/engine/trainer.py b/mdistiller/engine/trainer.py index 49a5fe4..94d1643 100755 --- a/mdistiller/engine/trainer.py +++ b/mdistiller/engine/trainer.py @@ -16,6 +16,7 @@ load_checkpoint, log_msg, ) +from .dot import DistillationOrientedTrainer class BaseTrainer(object): @@ -223,3 +224,151 @@ def train_iter(self, data, epoch, train_meters): train_meters["top5"].avg, ) return msg + + +class DOT(BaseTrainer): + def init_optimizer(self, cfg): + if cfg.SOLVER.TYPE == "SGD": + m_task = cfg.SOLVER.MOMENTUM - cfg.SOLVER.DOT.DELTA + m_kd = cfg.SOLVER.MOMENTUM + cfg.SOLVER.DOT.DELTA + optimizer = DistillationOrientedTrainer( + self.distiller.module.get_learnable_parameters(), + lr=cfg.SOLVER.LR, + momentum=m_task, + momentum_kd=m_kd, + 
weight_decay=cfg.SOLVER.WEIGHT_DECAY, + ) + else: + raise NotImplementedError(cfg.SOLVER.TYPE) + return optimizer + + def train(self, resume=False): + epoch = 1 + if resume: + state = load_checkpoint(os.path.join(self.log_path, "latest")) + epoch = state["epoch"] + 1 + self.distiller.load_state_dict(state["model"]) + self.optimizer.load_state_dict(state["optimizer"]) + self.best_acc = state["best_acc"] + while epoch < self.cfg.SOLVER.EPOCHS + 1: + self.train_epoch(epoch) + epoch += 1 + print(log_msg("Best accuracy:{}".format(self.best_acc), "EVAL")) + with open(os.path.join(self.log_path, "worklog.txt"), "a") as writer: + writer.write("best_acc\t" + "{:.2f}".format(float(self.best_acc))) + + def train_iter(self, data, epoch, train_meters): + train_start_time = time.time() + image, target, index = data + train_meters["data_time"].update(time.time() - train_start_time) + image = image.float() + image = image.cuda(non_blocking=True) + target = target.cuda(non_blocking=True) + index = index.cuda(non_blocking=True) + + # forward + preds, losses_dict = self.distiller(image=image, target=target, epoch=epoch) + + # dot backward + loss_ce, loss_kd = losses_dict['loss_ce'].mean(), losses_dict['loss_kd'].mean() + self.optimizer.zero_grad(set_to_none=True) + loss_kd.backward(retain_graph=True) + self.optimizer.step_kd() + self.optimizer.zero_grad(set_to_none=True) + loss_ce.backward() + self.optimizer.step() + + train_meters["training_time"].update(time.time() - train_start_time) + # collect info + batch_size = image.size(0) + acc1, acc5 = accuracy(preds, target, topk=(1, 5)) + train_meters["losses"].update((loss_ce + loss_kd).cpu().detach().numpy().mean(), batch_size) + train_meters["top1"].update(acc1[0], batch_size) + train_meters["top5"].update(acc5[0], batch_size) + # print info + msg = "Epoch:{}| Time(data):{:.3f}| Time(train):{:.3f}| Loss:{:.4f}| Top-1:{:.3f}| Top-5:{:.3f}".format( + epoch, + train_meters["data_time"].avg, + train_meters["training_time"].avg, + 
train_meters["losses"].avg, + train_meters["top1"].avg, + train_meters["top5"].avg, + ) + return msg + + +class CRDDOT(BaseTrainer): + + def init_optimizer(self, cfg): + if cfg.SOLVER.TYPE == "SGD": + m_task = cfg.SOLVER.MOMENTUM - cfg.SOLVER.DOT.DELTA + m_kd = cfg.SOLVER.MOMENTUM + cfg.SOLVER.DOT.DELTA + optimizer = DistillationOrientedTrainer( + self.distiller.module.get_learnable_parameters(), + lr=cfg.SOLVER.LR, + momentum=m_task, + momentum_kd=m_kd, + weight_decay=cfg.SOLVER.WEIGHT_DECAY, + ) + else: + raise NotImplementedError(cfg.SOLVER.TYPE) + return optimizer + + def train(self, resume=False): + epoch = 1 + if resume: + state = load_checkpoint(os.path.join(self.log_path, "latest")) + epoch = state["epoch"] + 1 + self.distiller.load_state_dict(state["model"]) + self.optimizer.load_state_dict(state["optimizer"]) + self.best_acc = state["best_acc"] + while epoch < self.cfg.SOLVER.EPOCHS + 1: + self.train_epoch(epoch) + epoch += 1 + print(log_msg("Best accuracy:{}".format(self.best_acc), "EVAL")) + with open(os.path.join(self.log_path, "worklog.txt"), "a") as writer: + writer.write("best_acc\t" + "{:.2f}".format(float(self.best_acc))) + + def train_iter(self, data, epoch, train_meters): + self.optimizer.zero_grad() + train_start_time = time.time() + image, target, index, contrastive_index = data + train_meters["data_time"].update(time.time() - train_start_time) + image = image.float() + image = image.cuda(non_blocking=True) + target = target.cuda(non_blocking=True) + index = index.cuda(non_blocking=True) + contrastive_index = contrastive_index.cuda(non_blocking=True) + + # forward + preds, losses_dict = self.distiller( + image=image, target=target, index=index, contrastive_index=contrastive_index + ) + + # dot backward + loss_ce, loss_kd = losses_dict['loss_ce'].mean(), losses_dict['loss_kd'].mean() + self.optimizer.zero_grad(set_to_none=True) + loss_kd.backward(retain_graph=True) + self.optimizer.step_kd() + self.optimizer.zero_grad(set_to_none=True) + 
loss_ce.backward() + # self.optimizer.step((1 - epoch / 240.)) + self.optimizer.step() + + train_meters["training_time"].update(time.time() - train_start_time) + # collect info + batch_size = image.size(0) + acc1, acc5 = accuracy(preds, target, topk=(1, 5)) + train_meters["losses"].update((loss_ce + loss_kd).cpu().detach().numpy().mean(), batch_size) + train_meters["top1"].update(acc1[0], batch_size) + train_meters["top5"].update(acc5[0], batch_size) + # print info + msg = "Epoch:{}| Time(data):{:.3f}| Time(train):{:.3f}| Loss:{:.4f}| Top-1:{:.3f}| Top-5:{:.3f}".format( + epoch, + train_meters["data_time"].avg, + train_meters["training_time"].avg, + train_meters["losses"].avg, + train_meters["top1"].avg, + train_meters["top5"].avg, + ) + return msg diff --git a/mdistiller/models/__init__.py b/mdistiller/models/__init__.py index e8411d7..624b968 100755 --- a/mdistiller/models/__init__.py +++ b/mdistiller/models/__init__.py @@ -1,2 +1,2 @@ -from .cifar import cifar_model_dict +from .cifar import cifar_model_dict, tiny_imagenet_model_dict from .imagenet import imagenet_model_dict diff --git a/mdistiller/models/cifar/__init__.py b/mdistiller/models/cifar/__init__.py index ccf023c..c39916f 100755 --- a/mdistiller/models/cifar/__init__.py +++ b/mdistiller/models/cifar/__init__.py @@ -16,12 +16,19 @@ from .mobilenetv2 import mobile_half from .ShuffleNetv1 import ShuffleV1 from .ShuffleNetv2 import ShuffleV2 +from .mv2_tinyimagenet import mobilenetv2_tinyimagenet cifar100_model_prefix = os.path.join( os.path.dirname(os.path.abspath(__file__)), "../../../download_ckpts/cifar_teachers/" ) + +tiny_imagenet_model_prefix = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "../../../download_ckpts/tiny_imagenet_teachers/" +) + cifar_model_dict = { # teachers "resnet56": ( @@ -64,3 +71,10 @@ "ShuffleV1": (ShuffleV1, None), "ShuffleV2": (ShuffleV2, None), } + + +tiny_imagenet_model_dict = { + "ResNet18": (ResNet18, tiny_imagenet_model_prefix + 
"ResNet18_vanilla/ti_res18"), + "MobileNetV2": (mobilenetv2_tinyimagenet, None), + "ShuffleV2": (ShuffleV2, None), +} diff --git a/mdistiller/models/cifar/mv2_tinyimagenet.py b/mdistiller/models/cifar/mv2_tinyimagenet.py new file mode 100644 index 0000000..4a88944 --- /dev/null +++ b/mdistiller/models/cifar/mv2_tinyimagenet.py @@ -0,0 +1,101 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LinearBottleNeck(nn.Module): + + def __init__(self, in_channels, out_channels, stride, t=6, class_num=100): + super().__init__() + + self.residual = nn.Sequential( + nn.Conv2d(in_channels, in_channels * t, 1), + nn.BatchNorm2d(in_channels * t), + nn.ReLU6(inplace=True), + + nn.Conv2d(in_channels * t, in_channels * t, 3, stride=stride, padding=1, groups=in_channels * t), + nn.BatchNorm2d(in_channels * t), + nn.ReLU6(inplace=True), + + nn.Conv2d(in_channels * t, out_channels, 1), + nn.BatchNorm2d(out_channels) + ) + + self.stride = stride + self.in_channels = in_channels + self.out_channels = out_channels + + def forward(self, x): + + residual = self.residual(x) + + if self.stride == 1 and self.in_channels == self.out_channels: + residual += x + + return residual + +class MobileNetV2(nn.Module): + + def __init__(self, num_classes=100): + super().__init__() + + self.pre = nn.Sequential( + nn.Conv2d(3, 32, 1, padding=1), + nn.BatchNorm2d(32), + nn.ReLU6(inplace=True) + ) + + self.stage1 = LinearBottleNeck(32, 16, 1, 1) + self.stage2 = self._make_stage(2, 16, 24, 2, 6) + self.stage3 = self._make_stage(3, 24, 32, 2, 6) + self.stage4 = self._make_stage(4, 32, 64, 2, 6) + self.stage5 = self._make_stage(3, 64, 96, 1, 6) + self.stage6 = self._make_stage(3, 96, 160, 1, 6) + self.stage7 = LinearBottleNeck(160, 320, 1, 6) + + self.conv1 = nn.Sequential( + nn.Conv2d(320, 1280, 1), + nn.BatchNorm2d(1280), + nn.ReLU6(inplace=True) + ) + + self.conv2 = nn.Conv2d(1280, num_classes, 1) + + def forward(self, x): + x = self.pre(x) + f0 = x + x = self.stage1(x) + x = 
self.stage2(x) + f1 = x + x = self.stage3(x) + f2 = x + x = self.stage4(x) + f3 = x + x = self.stage5(x) + x = self.stage6(x) + x = self.stage7(x) + x = self.conv1(x) + f4 = x + x = F.adaptive_avg_pool2d(x, 1) + avg = x + x = self.conv2(x) + x = x.view(x.size(0), -1) + feats = {} + feats["feats"] = [f0, f1, f2, f3, f4] + feats["pooled_feat"] = avg + + return x, feats + + def _make_stage(self, repeat, in_channels, out_channels, stride, t): + + layers = [] + layers.append(LinearBottleNeck(in_channels, out_channels, stride, t)) + + while repeat - 1: + layers.append(LinearBottleNeck(out_channels, out_channels, 1, t)) + repeat -= 1 + + return nn.Sequential(*layers) + +def mobilenetv2_tinyimagenet(**kwargs): + return MobileNetV2(**kwargs) \ No newline at end of file diff --git a/tools/eval.py b/tools/eval.py index 10c2f16..e92580b 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -5,7 +5,7 @@ cudnn.benchmark = True from mdistiller.distillers import Vanilla -from mdistiller.models import cifar_model_dict, imagenet_model_dict +from mdistiller.models import cifar_model_dict, imagenet_model_dict, tiny_imagenet_model_dict from mdistiller.dataset import get_dataset from mdistiller.dataset.imagenet import get_imagenet_val_loader from mdistiller.engine.utils import load_checkpoint, validate @@ -21,7 +21,7 @@ "--dataset", type=str, default="cifar100", - choices=["cifar100", "imagenet"], + choices=["cifar100", "imagenet", "tiny_imagenet"], ) parser.add_argument("-bs", "--batch-size", type=int, default=64) args = parser.parse_args() @@ -35,9 +35,10 @@ else: model = imagenet_model_dict[args.model](pretrained=False) model.load_state_dict(load_checkpoint(args.ckpt)["model"]) - elif args.dataset == "cifar100": + elif args.dataset in ("cifar100", "tiny_imagenet"): train_loader, val_loader, num_data, num_classes = get_dataset(cfg) - model, pretrain_model_path = cifar_model_dict[args.model] + model_dict = tiny_imagenet_model_dict if args.dataset == "tiny_imagenet" else cifar_model_dict + 
model, pretrain_model_path = model_dict[args.model] model = model(num_classes=num_classes) ckpt = pretrain_model_path if args.ckpt == "pretrain" else args.ckpt model.load_state_dict(load_checkpoint(ckpt)["model"]) diff --git a/tools/train.py b/tools/train.py index ea41035..de8dbf6 100755 --- a/tools/train.py +++ b/tools/train.py @@ -6,7 +6,7 @@ cudnn.benchmark = True -from mdistiller.models import cifar_model_dict, imagenet_model_dict +from mdistiller.models import cifar_model_dict, imagenet_model_dict, tiny_imagenet_model_dict from mdistiller.distillers import distiller_dict from mdistiller.dataset import get_dataset from mdistiller.engine.utils import load_checkpoint, log_msg @@ -43,6 +43,8 @@ def main(cfg, resume, opts): if cfg.DISTILLER.TYPE == "NONE": if cfg.DATASET.TYPE == "imagenet": model_student = imagenet_model_dict[cfg.DISTILLER.STUDENT](pretrained=False) + elif cfg.DATASET.TYPE == "tiny_imagenet": + model_student = tiny_imagenet_model_dict[cfg.DISTILLER.STUDENT][0](num_classes=num_classes) else: model_student = cifar_model_dict[cfg.DISTILLER.STUDENT][0]( num_classes=num_classes @@ -55,13 +57,14 @@ def main(cfg, resume, opts): model_teacher = imagenet_model_dict[cfg.DISTILLER.TEACHER](pretrained=True) model_student = imagenet_model_dict[cfg.DISTILLER.STUDENT](pretrained=False) else: - net, pretrain_model_path = cifar_model_dict[cfg.DISTILLER.TEACHER] + model_dict = tiny_imagenet_model_dict if cfg.DATASET.TYPE == "tiny_imagenet" else cifar_model_dict + net, pretrain_model_path = model_dict[cfg.DISTILLER.TEACHER] assert ( pretrain_model_path is not None ), "no pretrain model for teacher {}".format(cfg.DISTILLER.TEACHER) model_teacher = net(num_classes=num_classes) model_teacher.load_state_dict(load_checkpoint(pretrain_model_path)["model"]) - model_student = cifar_model_dict[cfg.DISTILLER.STUDENT][0]( + model_student = model_dict[cfg.DISTILLER.STUDENT][0]( num_classes=num_classes ) if cfg.DISTILLER.TYPE == "CRD":