Showing 10 changed files with 742 additions and 13 deletions.
20 changes: 20 additions & 0 deletions
examples/quantization_aware_training/coco2017/DETR/README.md
@@ -0,0 +1,20 @@
# DETR QAT example

## Preparation

The `DETR` pretrained model is the checkpoint from https://github.com/facebookresearch/detr . The example will automatically download the checkpoint using `torch.hub.load`.
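
For reference, `main.py` fetches the float DETR-R50 model via `torch.hub` (the entry point and weights come from the upstream DETR repo):

```python
import torch

# Downloads the pretrained float DETR-R50 checkpoint published by facebookresearch/detr.
model = torch.hub.load('facebookresearch/detr:main', 'detr_resnet50', pretrained=True)
```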

The datasets used in this example are the train and validation sets of COCO2017; they can be downloaded from http://cocodataset.org. The corresponding cocoapi should also be installed.

## Usage

```shell
python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py qconfig_lsq_8w8f.yaml --coco_path /path/to/coco
```
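
The qconfig name suggests LSQ quantization with 8-bit weights and 8-bit features. `main.py` also accepts `--resume` and `--eval`, so a trained quantized checkpoint can presumably be evaluated with something like:

```shell
python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py qconfig_lsq_8w8f.yaml \
    --coco_path /path/to/coco --resume /path/to/checkpoint.pth --eval
```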

## Metrics

|DETR-R50|mAP|AP50|AP75|remarks|
|-|-|-|-|-|
|Float|0.421|0.623|0.443|baseline|
|8w8f| | | | |
278 changes: 278 additions & 0 deletions
examples/quantization_aware_training/coco2017/DETR/main.py
@@ -0,0 +1,278 @@
import argparse
import datetime
import json
import random
import time
from pathlib import Path

import numpy as np
from sparsebit.quantization.modules.conv import QConv2d
import torch
from torch.utils.data import DataLoader, DistributedSampler

import sys
sys.path.append("./detr")
import detr.util.misc as utils
from detr.datasets import build_dataset, get_coco_api_from_dataset, coco
from detr.engine import evaluate, train_one_epoch
from model import build

from sparsebit.quantization import QuantModel, parse_qconfig

def get_args_parser():
    parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
    parser.add_argument("qconfig", help="the path of quant config")
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--lr_backbone', default=1e-5, type=float)
    parser.add_argument('--batch_size', default=1, type=int)
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--epochs', default=300, type=int)
    parser.add_argument('--lr_drop', default=200, type=int)
    parser.add_argument('--clip_max_norm', default=0.1, type=float,
                        help='gradient clipping max norm')

    # Model parameters
    parser.add_argument('--frozen_weights', type=str, default=None,
                        help="Path to the pretrained model. If set, only the mask head will be trained")
    # * Backbone
    parser.add_argument('--backbone', default='resnet50', type=str,
                        help="Name of the convolutional backbone to use")
    parser.add_argument('--dilation', action='store_true',
                        help="If true, we replace stride with dilation in the last convolutional block (DC5)")
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                        help="Type of positional embedding to use on top of the image features")

    # * Transformer
    parser.add_argument('--enc_layers', default=6, type=int,
                        help="Number of encoding layers in the transformer")
    parser.add_argument('--dec_layers', default=6, type=int,
                        help="Number of decoding layers in the transformer")
    parser.add_argument('--dim_feedforward', default=2048, type=int,
                        help="Intermediate size of the feedforward layers in the transformer blocks")
    parser.add_argument('--hidden_dim', default=256, type=int,
                        help="Size of the embeddings (dimension of the transformer)")
    parser.add_argument('--dropout', default=0.1, type=float,
                        help="Dropout applied in the transformer")
    parser.add_argument('--nheads', default=8, type=int,
                        help="Number of attention heads inside the transformer's attentions")
    parser.add_argument('--num_queries', default=100, type=int,
                        help="Number of query slots")
    parser.add_argument('--pre_norm', action='store_true')

    # * Segmentation
    parser.add_argument('--masks', action='store_true',
                        help="Train segmentation head if the flag is provided")

    # Loss
    parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
                        help="Disables auxiliary decoding losses (loss at each layer)")
    # * Matcher
    parser.add_argument('--set_cost_class', default=1, type=float,
                        help="Class coefficient in the matching cost")
    parser.add_argument('--set_cost_bbox', default=5, type=float,
                        help="L1 box coefficient in the matching cost")
    parser.add_argument('--set_cost_giou', default=2, type=float,
                        help="giou box coefficient in the matching cost")
    # * Loss coefficients
    parser.add_argument('--mask_loss_coef', default=1, type=float)
    parser.add_argument('--dice_loss_coef', default=1, type=float)
    parser.add_argument('--bbox_loss_coef', default=5, type=float)
    parser.add_argument('--giou_loss_coef', default=2, type=float)
    parser.add_argument('--eos_coef', default=0.1, type=float,
                        help="Relative classification weight of the no-object class")

    # dataset parameters
    parser.add_argument('--dataset_file', default='coco')
    parser.add_argument('--coco_path', type=str)
    parser.add_argument('--coco_panoptic_path', type=str)
    parser.add_argument('--remove_difficult', action='store_true')

    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--num_workers', default=2, type=int)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    return parser


def main(args):
    utils.init_distributed_mode(args)
    print("git:\n {}\n".format(utils.get_sha()))

    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"
    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model = torch.hub.load('facebookresearch/detr:main', 'detr_resnet50', pretrained=True)
    model, criterion, postprocessors = build(args, model)

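    # Build the quantized model: parse the YAML qconfig and wrap the float DETR in a
    # sparsebit QuantModel, which inserts weight/activation quantizers per the config.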
    qconfig = parse_qconfig(args.qconfig)
    qmodel = QuantModel(model, config=qconfig).cuda()

    dataset_train = build_dataset(image_set='train', args=args)
    dataset_val = build_dataset(image_set='val', args=args)
    dataset_calib = build_dataset(image_set='train', args=args)

    if args.distributed:
        sampler_train = DistributedSampler(dataset_train)
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(
        sampler_train, args.batch_size, drop_last=True)

    data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
                                   collate_fn=utils.collate_fn, num_workers=args.num_workers)
    data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
                                 drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)
    data_loader_calib = DataLoader(dataset_calib, args.batch_size, sampler=None,
                                   drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)
    if args.dataset_file == "coco_panoptic":
        # We also evaluate AP during panoptic training, on original coco DS
        coco_val = coco.build("val", args)
        base_ds = get_coco_api_from_dataset(coco_val)
    else:
        base_ds = get_coco_api_from_dataset(dataset_val)

    # for n, m in model.model.named_modules():
    #     if isinstance(m, QConv2d) and "backbone" in n:
    #         m.input_quantizer.set_bit(bit=4)
    #         m.weight_quantizer.set_bit(bit=4)
    # model.model.backbone_0_body_conv1.input_quantizer.set_bit(bit=8)
    # model.model.backbone_0_body_conv1.weight_quantizer.set_bit(bit=8)

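    # Calibration: feed ~16 training images through the quantized model in eval mode so
    # the quantizers can observe activation ranges, then switch the model into QAT mode.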
    qmodel.prepare_calibration()
    calib_size, cur_size = 16, 0
    qmodel.eval()
    # model = qmodel._replace_complicated_operators(model).cuda()
    with torch.no_grad():
        for samples, _ in data_loader_calib:
            # out = model(samples.to(device))
            qmodel(samples.to(device))
            cur_size += args.batch_size
            if cur_size >= calib_size:
                break
    qmodel.init_QAT()

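    # Convert BatchNorm layers to SyncBatchNorm across all ranks, then wrap the
    # quantized model with DDP for distributed QAT fine-tuning.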
    process_group = torch.distributed.new_group([i for i in range(args.world_size)])
    qmodel_without_ddp = torch.nn.SyncBatchNorm.convert_sync_batchnorm(qmodel, process_group)

    if args.distributed:
        qmodel = torch.nn.parallel.DistributedDataParallel(qmodel, device_ids=[args.gpu])
        qmodel_without_ddp = qmodel.module
    n_parameters = sum(p.numel() for p in qmodel.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

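    # Two parameter groups, following the original DETR recipe: the backbone is trained
    # with a smaller learning rate (--lr_backbone) than the rest of the network.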
    param_dicts = [
        {"params": [p for n, p in qmodel_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]},
        {
            "params": [p for n, p in qmodel_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
            "lr": args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                                  weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    if args.frozen_weights is not None:
        checkpoint = torch.load(args.frozen_weights, map_location='cpu')
        qmodel_without_ddp.detr.load_state_dict(checkpoint['model'])

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        qmodel_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1

    if args.eval:
        # Evaluate the quantized model (with any resumed weights) rather than the float model.
        test_stats, coco_evaluator = evaluate(qmodel, criterion, postprocessors,
                                              data_loader_val, base_ds, device, args.output_dir)
        if args.output_dir:
            utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
        return

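    # QAT fine-tuning: the standard DETR training loop, run on the quantized model.
    # Checkpoints and COCO evaluation results are written to --output_dir each epoch.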
print("Start training") | ||
start_time = time.time() | ||
for epoch in range(args.start_epoch, args.epochs): | ||
if args.distributed: | ||
sampler_train.set_epoch(epoch) | ||
train_stats = train_one_epoch( | ||
qmodel, criterion, data_loader_train, optimizer, device, epoch, | ||
args.clip_max_norm) | ||
lr_scheduler.step() | ||
if args.output_dir: | ||
checkpoint_paths = [output_dir / 'checkpoint.pth'] | ||
# extra checkpoint before LR drop and every 100 epochs | ||
if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0: | ||
checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth') | ||
for checkpoint_path in checkpoint_paths: | ||
utils.save_on_master({ | ||
'model': qmodel_without_ddp.state_dict(), | ||
'optimizer': optimizer.state_dict(), | ||
'lr_scheduler': lr_scheduler.state_dict(), | ||
'epoch': epoch, | ||
'args': args, | ||
}, checkpoint_path) | ||
|
||
test_stats, coco_evaluator = evaluate( | ||
qmodel, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir | ||
) | ||
|
||
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, | ||
**{f'test_{k}': v for k, v in test_stats.items()}, | ||
'epoch': epoch, | ||
'n_parameters': n_parameters} | ||
|
||
if args.output_dir and utils.is_main_process(): | ||
with (output_dir / "log.txt").open("a") as f: | ||
f.write(json.dumps(log_stats) + "\n") | ||
|
||
# for evaluation logs | ||
if coco_evaluator is not None: | ||
(output_dir / 'eval').mkdir(exist_ok=True) | ||
if "bbox" in coco_evaluator.coco_eval: | ||
filenames = ['latest.pth'] | ||
if epoch % 50 == 0: | ||
filenames.append(f'{epoch:03}.pth') | ||
for name in filenames: | ||
torch.save(coco_evaluator.coco_eval["bbox"].eval, | ||
output_dir / "eval" / name) | ||
|
||
total_time = time.time() - start_time | ||
total_time_str = str(datetime.timedelta(seconds=int(total_time))) | ||
print('Training time {}'.format(total_time_str)) | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()]) | ||
args = parser.parse_args() | ||
if args.output_dir: | ||
Path(args.output_dir).mkdir(parents=True, exist_ok=True) | ||
main(args) |