Commit

add bias correction
add bias correction

rectify

inprogressing

resnet18 ok

resnet18 ok

deit ok

deit ok

black

rebase

black
Jiang-Stan committed Jan 3, 2023
1 parent 4219b20 commit fd322c8
Showing 4 changed files with 58 additions and 3 deletions.
6 changes: 5 additions & 1 deletion
@@ -15,4 +15,8 @@ A:
   BIT: 8
   OBSERVER:
     TYPE: MINMAX
-    LAYOUT: NCHW
+    LAYOUT: NLC
+  SPECIFIC: [{
+    "patch_embed_proj": ["OBSERVER.LAYOUT", "NCHW"],
+    "head": ["OBSERVER.LAYOUT", "NCHW"],
+  }]
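
For context: this first hunk edits an example qconfig YAML (a DeiT-style config, judging by the patch_embed_proj and head entries). OBSERVER.LAYOUT tells the activation observer which axis of the tensor holds channels when it collects per-channel statistics; the transformer blocks emit token-major NLC tensors, while patch_embed_proj and head are overridden back to NCHW via SPECIFIC. A rough sketch of the layout-to-channel-axis idea, not sparsebit's observer implementation:

# Illustration only (not sparsebit's code): how a layout string could map to the
# channel axis used when gathering per-channel min/max statistics.
import torch

def channel_axis(layout: str) -> int:
    # NCHW -> channels at dim 1; NLC (batch, tokens, channels) -> channels at dim 2
    return {"NCHW": 1, "NLC": 2}[layout]

def per_channel_minmax(x: torch.Tensor, layout: str):
    ch = channel_axis(layout)
    flat = x.transpose(ch, 0).flatten(1)  # (channels, everything else)
    return flat.min(dim=1).values, flat.max(dim=1).values

tokens = torch.randn(8, 197, 384)           # DeiT-style activations in NLC
lo, hi = per_channel_minmax(tokens, "NLC")  # 384 per-channel ranges
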
1 change: 1 addition & 0 deletions sparsebit/quantization/quant_config.py
@@ -9,6 +9,7 @@
 
 _C.SCHEDULE = CN()
 _C.SCHEDULE.FUSE_BN = False # use ``with torch.no_grad()`` if it's enabled
+_C.SCHEDULE.BIAS_CORRECTION = False
 _C.SCHEDULE.BN_TUNING = False
 _C.SCHEDULE.DISABLE_UNNECESSARY_QUANT = True
 
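
The new SCHEDULE.BIAS_CORRECTION flag defaults to False, so bias correction stays opt-in. A minimal sketch of flipping it on from Python, assuming the module-level yacs config shown above is importable as-is (a project YAML would instead set SCHEDULE: BIAS_CORRECTION: True):

# Sketch only: clone the yacs defaults from this file and enable the new flag.
from sparsebit.quantization.quant_config import _C

cfg = _C.clone()
cfg.defrost()                        # in case the defaults are frozen
cfg.SCHEDULE.BIAS_CORRECTION = True  # consumed by prepare_calibration() below
cfg.freeze()
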
4 changes: 3 additions & 1 deletion sparsebit/quantization/quant_model.py
@@ -185,7 +185,9 @@ def prepare_calibration(self):
         from sparsebit.quantization.tools.calibration import CalibrationRunner
 
         self.eval()
-        self.calibration_runner = CalibrationRunner(self.model)
+        self.calibration_runner = CalibrationRunner(
+            self.model, self.cfg.SCHEDULE.BIAS_CORRECTION
+        )
         self.calibration_runner.prepare_calibration()
 
     def calc_qparams(self, asym=False, w_quant=False, a_quant=False):
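
With the flag threaded into CalibrationRunner here, the existing calibration flow picks it up without any new user-facing call. A hedged sketch of that flow; only prepare_calibration and calc_qparams appear in this diff, so the model construction and data feeding below are assumptions modeled on the usual sparsebit examples:

# Sketch only: names other than prepare_calibration / calc_qparams are assumed.
import torch

qmodel = build_quant_model(float_model, cfg)  # hypothetical helper standing in for sparsebit's model build
qmodel.prepare_calibration()                  # now builds CalibrationRunner(model, cfg.SCHEDULE.BIAS_CORRECTION)
with torch.no_grad():
    for images, _ in calib_loader:            # a small calibration set is enough
        qmodel(images)                        # cached for layerwise calibration
qmodel.calc_qparams()                         # per-node calibration, incl. run_bias_correction when enabled
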
50 changes: 49 additions & 1 deletion sparsebit/quantization/tools/calibration.py
@@ -1,5 +1,6 @@
 import copy
 import torch
+import torch.nn as nn
 from functools import partial
 
 from sparsebit.quantization.modules import QuantOpr
@@ -9,8 +10,9 @@
 
 
 class CalibrationRunner(object):
-    def __init__(self, model):
+    def __init__(self, model, bias_correction=False):
         self.model = fx_symbolic_trace(model)
+        self.bias_correction = bias_correction
 
     def prepare_calibration(self):
         input_names_cache = set(
@@ -96,6 +98,9 @@ def layerwise_calibration(self, device, asym=False, w_quant=False, a_quant=False
                 )
                 self.builder.qstorage.set_output(node.target, quant_outputs)
                 self.builder.qstorage.finish_node(node.target)
+                # bias correction
+                if self.bias_correction:
+                    self.run_bias_correction(batch_num, node, device)
             # pop the outputs of nodes whose out-degree=0
             self.builder.storage.finish_node(node.target)
 
@@ -153,3 +158,46 @@ def module_forward(
         if isinstance(module, QuantOpr):
             module.set_quant(w_quant=False, a_quant=False)
         return outputs
+
+    def run_bias_correction(self, batch_num, node, device):
+        module = getattr(self.model, node.target)
+        if isinstance(module, QuantOpr) and getattr(module, "weight_quantizer", None):
+            for inp_node in node.all_input_nodes:
+                inp_tensors = self.builder.storage.get_output(inp_node.target)
+                float_outputs = torch.Tensor([])
+                quant_outputs = torch.Tensor([])
+                float_outputs_cached = self.builder.storage.get_output(node.target)
+                for idx in range(batch_num):
+                    inp_tensor = inp_tensors[idx].cuda()
+                    with torch.no_grad():
+                        float_output = float_outputs_cached[idx]
+                        module.set_quant(True, False)
+                        quant_output = module(inp_tensor).cpu()
+                        module.set_quant(False, False)
+                    float_outputs = torch.cat(
+                        (float_outputs, float_output.detach()), 0
+                    )
+                    quant_outputs = torch.cat(
+                        (quant_outputs, quant_output.detach()), 0
+                    )
+                float_output_mean = (
+                    float_outputs.transpose(module.input_quantizer.qdesc._ch_axis, 0)
+                    .flatten(1)
+                    .mean(-1)
+                )
+                quant_output_mean = (
+                    quant_outputs.transpose(module.input_quantizer.qdesc._ch_axis, 0)
+                    .flatten(1)
+                    .mean(-1)
+                )
+                bias = quant_output_mean - float_output_mean
+                if module.bias is None:
+                    module.bias = nn.Parameter(
+                        data=torch.zeros(
+                            module.weight.size(0),
+                            dtype=torch.float32,
+                            device=device,
+                        ),
+                        requires_grad=False,
+                    )
+                module.bias.data = module.bias.data - bias.cuda()
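
run_bias_correction compares, per output channel, the mean of the cached float outputs with the mean recomputed under quantized weights across the calibration batches, then subtracts that difference from the layer bias (creating a zero bias first when the layer has none), so the quantized layer's expected output matches the float one on the calibration data. A self-contained sketch of the same idea on a toy Conv2d, independent of sparsebit's classes and with a made-up fake-quantizer:

# Standalone illustration of bias correction; fake_quant is invented for this demo.
import torch
import torch.nn as nn

def fake_quant(w: torch.Tensor, bits: int = 4) -> torch.Tensor:
    # symmetric per-tensor fake quantization, purely for illustration
    qmax = 2 ** (bits - 1) - 1
    scale = w.abs().max() / qmax
    return (w / scale).round().clamp(-qmax - 1, qmax) * scale

conv = nn.Conv2d(8, 16, 3, padding=1, bias=True)
x = torch.randn(32, 8, 14, 14)  # one "calibration batch" in NCHW

with torch.no_grad():
    y_float = conv(x)
    conv.weight.data = fake_quant(conv.weight.data)
    y_quant = conv(x)
    # per-output-channel mean shift introduced by weight quantization (channel axis 1 in NCHW)
    delta = y_quant.mean(dim=(0, 2, 3)) - y_float.mean(dim=(0, 2, 3))
    conv.bias.data -= delta  # same correction as above: bias <- bias - (E[quant] - E[float])
    y_corrected = conv(x)

print((y_quant.mean(dim=(0, 2, 3)) - y_float.mean(dim=(0, 2, 3))).abs().max())      # visible shift
print((y_corrected.mean(dim=(0, 2, 3)) - y_float.mean(dim=(0, 2, 3))).abs().max())  # ~0 after correction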
