Commit

detr qat
debugging

support detr qat

remove print
Jiang-Stan committed Dec 7, 2022
1 parent 3aec099 commit d1db231
Showing 10 changed files with 742 additions and 13 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
@@ -4,3 +4,6 @@
[submodule "examples/post_training_quantization/coco2017/DETR/detr"]
path = examples/post_training_quantization/coco2017/DETR/detr
url = https://github.com/facebookresearch/detr.git
[submodule "examples/quantization_aware_training/coco2017/DETR/detr"]
path = examples/quantization_aware_training/coco2017/DETR/detr
url = https://github.com/facebookresearch/detr.git
20 changes: 20 additions & 0 deletions examples/quantization_aware_training/coco2017/DETR/README.md
@@ -0,0 +1,20 @@
# DETR QAT example

## Preparation

The pretrained `DETR` model is the checkpoint from https://github.com/facebookresearch/detr. The example downloads it automatically via `torch.hub.load`.
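
For reference, the download amounts to the `torch.hub` call that `main.py` issues:

```python
import torch

# fetches the official pretrained DETR-R50 float checkpoint
model = torch.hub.load('facebookresearch/detr:main', 'detr_resnet50', pretrained=True)
```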

The example uses the train and validation splits of COCO2017, which can be downloaded from http://cocodataset.org. The corresponding cocoapi (e.g. `pip install pycocotools`) must also be installed.

## Usage

```shell
python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py qconfig_lsq_8w8f.yaml --coco_path /path/to/coco
```
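
Under the hood, `main.py` wraps the float model with SparseBit's `QuantModel`, calibrates it on a few training batches, and then switches to QAT mode. A condensed sketch of that flow is below; `model`, `data_loader_calib`, and `batch_size` stand in for the objects built in `main.py`:

```python
import torch
from sparsebit.quantization import QuantModel, parse_qconfig

# `model`, `data_loader_calib`, and `batch_size` come from the surrounding
# script, as in main.py
qconfig = parse_qconfig("qconfig_lsq_8w8f.yaml")
qmodel = QuantModel(model, config=qconfig).cuda()

# feed a few batches so the quantizers can observe activation ranges
qmodel.prepare_calibration()
qmodel.eval()
seen = 0
with torch.no_grad():
    for samples, _ in data_loader_calib:
        qmodel(samples.cuda())
        seen += batch_size
        if seen >= 16:  # main.py calibrates on 16 images
            break

# switch from calibration to quantization-aware training
qmodel.init_QAT()
```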

## Metrics

|DETR-R50|mAP|AP50|AP75|remarks|
|-|-|-|-|-|
|Float|0.421|0.623|0.443|baseline|
|8w8f| | | | |
1 change: 1 addition & 0 deletions examples/quantization_aware_training/coco2017/DETR/detr
Submodule detr added at 8a144f
278 changes: 278 additions & 0 deletions examples/quantization_aware_training/coco2017/DETR/main.py
@@ -0,0 +1,278 @@
import argparse
import datetime
import json
import random
import time
from pathlib import Path

import numpy as np
from sparsebit.quantization.modules.conv import QConv2d
import torch
from torch.utils.data import DataLoader, DistributedSampler

import sys
sys.path.append("./detr")
import detr.util.misc as utils
from detr.datasets import build_dataset, get_coco_api_from_dataset, coco
from detr.engine import evaluate, train_one_epoch
from model import build

from sparsebit.quantization import QuantModel, parse_qconfig

def get_args_parser():
parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
parser.add_argument("qconfig", help="the path of quant config")
parser.add_argument('--lr', default=1e-4, type=float)
parser.add_argument('--lr_backbone', default=1e-5, type=float)
parser.add_argument('--batch_size', default=1, type=int)
parser.add_argument('--weight_decay', default=1e-4, type=float)
parser.add_argument('--epochs', default=300, type=int)
parser.add_argument('--lr_drop', default=200, type=int)
parser.add_argument('--clip_max_norm', default=0.1, type=float,
help='gradient clipping max norm')

# Model parameters
parser.add_argument('--frozen_weights', type=str, default=None,
help="Path to the pretrained model. If set, only the mask head will be trained")
# * Backbone
parser.add_argument('--backbone', default='resnet50', type=str,
help="Name of the convolutional backbone to use")
parser.add_argument('--dilation', action='store_true',
help="If true, we replace stride with dilation in the last convolutional block (DC5)")
parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
help="Type of positional embedding to use on top of the image features")

# * Transformer
parser.add_argument('--enc_layers', default=6, type=int,
help="Number of encoding layers in the transformer")
parser.add_argument('--dec_layers', default=6, type=int,
help="Number of decoding layers in the transformer")
parser.add_argument('--dim_feedforward', default=2048, type=int,
help="Intermediate size of the feedforward layers in the transformer blocks")
parser.add_argument('--hidden_dim', default=256, type=int,
help="Size of the embeddings (dimension of the transformer)")
parser.add_argument('--dropout', default=0.1, type=float,
help="Dropout applied in the transformer")
parser.add_argument('--nheads', default=8, type=int,
help="Number of attention heads inside the transformer's attentions")
parser.add_argument('--num_queries', default=100, type=int,
help="Number of query slots")
parser.add_argument('--pre_norm', action='store_true')

# * Segmentation
parser.add_argument('--masks', action='store_true',
help="Train segmentation head if the flag is provided")

# Loss
parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
help="Disables auxiliary decoding losses (loss at each layer)")
# * Matcher
parser.add_argument('--set_cost_class', default=1, type=float,
help="Class coefficient in the matching cost")
parser.add_argument('--set_cost_bbox', default=5, type=float,
help="L1 box coefficient in the matching cost")
parser.add_argument('--set_cost_giou', default=2, type=float,
help="giou box coefficient in the matching cost")
# * Loss coefficients
parser.add_argument('--mask_loss_coef', default=1, type=float)
parser.add_argument('--dice_loss_coef', default=1, type=float)
parser.add_argument('--bbox_loss_coef', default=5, type=float)
parser.add_argument('--giou_loss_coef', default=2, type=float)
parser.add_argument('--eos_coef', default=0.1, type=float,
help="Relative classification weight of the no-object class")

# dataset parameters
parser.add_argument('--dataset_file', default='coco')
parser.add_argument('--coco_path', type=str)
parser.add_argument('--coco_panoptic_path', type=str)
parser.add_argument('--remove_difficult', action='store_true')

parser.add_argument('--output_dir', default='',
help='path where to save, empty for no saving')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=42, type=int)
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--eval', action='store_true')
parser.add_argument('--num_workers', default=2, type=int)

# distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
return parser


def main(args):
utils.init_distributed_mode(args)
print("git:\n {}\n".format(utils.get_sha()))

if args.frozen_weights is not None:
assert args.masks, "Frozen training is meant for segmentation only"
print(args)

device = torch.device(args.device)

# fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

model = torch.hub.load('facebookresearch/detr:main', 'detr_resnet50', pretrained=True)
model, criterion, postprocessors = build(args, model)

qconfig = parse_qconfig(args.qconfig)
qmodel = QuantModel(model, config=qconfig).cuda()
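# QuantModel traces the float network and swaps supported modules for
# quantized counterparts according to the qconfig (e.g. Conv2d -> QConv2d)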

dataset_train = build_dataset(image_set='train', args=args)
dataset_val = build_dataset(image_set='val', args=args)
dataset_calib = build_dataset(image_set='train', args=args)

if args.distributed:
sampler_train = DistributedSampler(dataset_train)
sampler_val = DistributedSampler(dataset_val, shuffle=False)
else:
sampler_train = torch.utils.data.RandomSampler(dataset_train)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)

batch_sampler_train = torch.utils.data.BatchSampler(
sampler_train, args.batch_size, drop_last=True)

data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
collate_fn=utils.collate_fn, num_workers=args.num_workers)
data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)
data_loader_calib = DataLoader(dataset_calib, args.batch_size, sampler=None,
drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)
if args.dataset_file == "coco_panoptic":
# We also evaluate AP during panoptic training, on original coco DS
coco_val = coco.build("val", args)
base_ds = get_coco_api_from_dataset(coco_val)
else:
base_ds = get_coco_api_from_dataset(dataset_val)

# for n, m in model.model.named_modules():
# if isinstance(m, QConv2d) and "backbone" in n:
# m.input_quantizer.set_bit(bit=4)
# m.weight_quantizer.set_bit(bit=4)
# model.model.backbone_0_body_conv1.input_quantizer.set_bit(bit=8)
# model.model.backbone_0_body_conv1.weight_quantizer.set_bit(bit=8)

qmodel.prepare_calibration()
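# with calibration mode on, the loop below feeds a small number of batches
# so the quantizers can observe activation ranges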
calib_size, cur_size = 16, 0
qmodel.eval()
# model = qmodel._replace_complicated_operators(model).cuda()
with torch.no_grad():
for samples, _ in data_loader_calib:
# out = model(samples.to(device))
qmodel(samples.to(device))
cur_size += args.batch_size
if cur_size >= calib_size:
break
qmodel.init_QAT()
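# init_QAT initializes the quantizer parameters from the calibration
# statistics and enables fake-quant for the training loop below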

# convert BatchNorm to SyncBatchNorm on the model that will be trained
process_group = torch.distributed.new_group(list(range(args.world_size)))
qmodel = torch.nn.SyncBatchNorm.convert_sync_batchnorm(qmodel, process_group)
qmodel_without_ddp = qmodel

if args.distributed:
qmodel = torch.nn.parallel.DistributedDataParallel(qmodel, device_ids=[args.gpu])
qmodel_without_ddp = qmodel.module
n_parameters = sum(p.numel() for p in qmodel.parameters() if p.requires_grad)
print('number of params:', n_parameters)

param_dicts = [
{"params": [p for n, p in qmodel_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]},
{
"params": [p for n, p in qmodel_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
"lr": args.lr_backbone,
},
]
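# two parameter groups: the backbone uses the smaller args.lr_backbone,
# everything else uses args.lr, following the original DETR recipe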
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

if args.frozen_weights is not None:
checkpoint = torch.load(args.frozen_weights, map_location='cpu')
qmodel_without_ddp.detr.load_state_dict(checkpoint['model'])

output_dir = Path(args.output_dir)
if args.resume:
if args.resume.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.resume, map_location='cpu')
qmodel_without_ddp.load_state_dict(checkpoint['model'])
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1

if args.eval:
test_stats, coco_evaluator = evaluate(qmodel, criterion, postprocessors,
data_loader_val, base_ds, device, args.output_dir)
if args.output_dir:
utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
return

print("Start training")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
sampler_train.set_epoch(epoch)
train_stats = train_one_epoch(
qmodel, criterion, data_loader_train, optimizer, device, epoch,
args.clip_max_norm)
lr_scheduler.step()
if args.output_dir:
checkpoint_paths = [output_dir / 'checkpoint.pth']
# extra checkpoint before LR drop and every 100 epochs
if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0:
checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
for checkpoint_path in checkpoint_paths:
utils.save_on_master({
'model': qmodel_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'args': args,
}, checkpoint_path)

test_stats, coco_evaluator = evaluate(
qmodel, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir
)

log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters}

if args.output_dir and utils.is_main_process():
with (output_dir / "log.txt").open("a") as f:
f.write(json.dumps(log_stats) + "\n")

# for evaluation logs
if coco_evaluator is not None:
(output_dir / 'eval').mkdir(exist_ok=True)
if "bbox" in coco_evaluator.coco_eval:
filenames = ['latest.pth']
if epoch % 50 == 0:
filenames.append(f'{epoch:03}.pth')
for name in filenames:
torch.save(coco_evaluator.coco_eval["bbox"].eval,
output_dir / "eval" / name)

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
args = parser.parse_args()
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)