
Commit 5835396

add mixbit for convnet
fix bugs

fix bugs

support DeiT

black
Jiang-Stan committed Jan 6, 2023
1 parent 25f852e commit 5835396
Showing 19 changed files with 368 additions and 37 deletions.
@@ -15,4 +15,8 @@ A:
   BIT: 8
   OBSERVER:
     TYPE: MINMAX
-    LAYOUT: NCHW
+    LAYOUT: NLC
+  SPECIFIC: [{
+    "patch_embed_proj": ["OBSERVER.LAYOUT", "NCHW"],
+    "head": ["OBSERVER.LAYOUT", "NCHW"],
+  }]
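Note: with DeiT support, most activations flow through transformer blocks in token layout (N, L, C), so the default observer LAYOUT becomes NLC; the SPECIFIC list pins the patch-embedding projection and the head back to NCHW, since their tensors are not token sequences. A sketch of how such a per-layer override could be resolved (the helper and the plain-dict config below are illustrative assumptions, not this repo's API):

    def resolve_layout(layer_name, cfg):
        # hypothetical helper: fall back to the global OBSERVER.LAYOUT unless
        # a SPECIFIC entry overrides it for this layer
        for spec in cfg.get("SPECIFIC", []):
            if layer_name in spec:
                key, value = spec[layer_name]
                if key == "OBSERVER.LAYOUT":
                    return value
        return cfg["OBSERVER"]["LAYOUT"]

    cfg = {
        "OBSERVER": {"LAYOUT": "NLC"},
        "SPECIFIC": [{
            "patch_embed_proj": ["OBSERVER.LAYOUT", "NCHW"],
            "head": ["OBSERVER.LAYOUT", "NCHW"],
        }],
    }
    assert resolve_layout("blocks_0_attn_qkv", cfg) == "NLC"
    assert resolve_layout("patch_embed_proj", cfg) == "NCHW"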
23 changes: 23 additions & 0 deletions sparsebit/quantization/bit_allocation/__init__.py
@@ -0,0 +1,23 @@
from .utils import *
from .ilp import weight_ilp_search, feature_ilp_search
from .metric import metric_factory
from .bit_allocation import *


def bit_allocation(qmodel, data):
    target_w_bit = qmodel.cfg.SCHEDULE.BIT_ALLOCATION.AVG_WEIGHT_BIT_TARGET
    target_a_bit = qmodel.cfg.SCHEDULE.BIT_ALLOCATION.AVG_FEATURE_BIT_TARGET
    (
        bops_limitation,
        bops_limitation_for_feature_search,
        memory_limitation,
    ) = calc_flops_and_limitations(qmodel.model, target_w_bit, target_a_bit)
    feature_perturbations, weight_perturbations = metric_factory["greedy"](qmodel, data)
    feature_bit_allocated = feature_ilp_search(
        qmodel, feature_perturbations, bops_limitation_for_feature_search
    )
    feature_bit_allocation(qmodel, feature_bit_allocated)
    weight_bit_allocated = weight_ilp_search(
        qmodel, weight_perturbations, bops_limitation, memory_limitation
    )
    weight_bit_allocation(qmodel, weight_bit_allocated)
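Note: this is the entry point of the mixed-bit pipeline: compute budgets, score per-layer/per-bit perturbations greedily, then run two ILP searches (features first, then weights) and apply the chosen bits. A hypothetical driver, for orientation only; the QuantModel construction and calibration calls are assumptions based on sparsebit's examples, not part of this commit:

    import torch
    import torchvision
    from sparsebit.quantization import QuantModel, parse_qconfig
    from sparsebit.quantization.bit_allocation import bit_allocation

    float_model = torchvision.models.resnet18(pretrained=True)
    qconfig = parse_qconfig("qconfig.yaml")     # hypothetical config path
    qmodel = QuantModel(float_model, qconfig).cuda()

    calib_batch = torch.randn(32, 3, 224, 224)  # stand-in calibration data
    qmodel.prepare_calibration()                # assumed observer-hook API
    _ = qmodel(calib_batch.cuda())              # populate observer data caches
    qmodel.calc_qparams()                       # assumed: fills per-bit scales
    bit_allocation(qmodel, calib_batch)         # greedy metric + ILP + assignment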
33 changes: 33 additions & 0 deletions sparsebit/quantization/bit_allocation/bit_allocation.py
@@ -0,0 +1,33 @@
from sparsebit.quantization.modules import QuantOpr


def feature_bit_allocation(qmodel, bit_allocated):
    for node in qmodel.model.graph.nodes:
        if node.target in bit_allocated.keys():
            bit = bit_allocated[node.target]
            module = getattr(qmodel.model, node.target)
            module.input_quantizer.set_bit(bit)
            module.input_quantizer.scale = module.input_quantizer._broadcast_qparams(
                module.input_quantizer.observer.scales[bit]
            )
            module.input_quantizer.zero_point = (
                module.input_quantizer._broadcast_qparams(
                    module.input_quantizer.observer.zero_points[bit]
                )
            )


def weight_bit_allocation(qmodel, bit_allocated):
    for node in qmodel.model.graph.nodes:
        if node.target in bit_allocated.keys():
            bit = bit_allocated[node.target]
            module = getattr(qmodel.model, node.target)
            module.weight_quantizer.set_bit(bit)
            module.weight_quantizer.scale = module.weight_quantizer._broadcast_qparams(
                module.weight_quantizer.observer.scales[bit]
            )
            module.weight_quantizer.zero_point = (
                module.weight_quantizer._broadcast_qparams(
                    module.weight_quantizer.observer.zero_points[bit]
                )
            )
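Note: both helpers repeat one pattern: choose the bit, then refresh scale and zero_point from the observer's per-bit tables, since observer.scales and observer.zero_points are keyed by bit width and the quantizer's cached qparams go stale after set_bit. A compact equivalent of the repeated block (helper name hypothetical):

    def _apply_bit(quantizer, bit):
        # hypothetical refactor of the pattern used in both functions above
        quantizer.set_bit(bit)
        quantizer.scale = quantizer._broadcast_qparams(quantizer.observer.scales[bit])
        quantizer.zero_point = quantizer._broadcast_qparams(
            quantizer.observer.zero_points[bit]
        )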
126 changes: 126 additions & 0 deletions sparsebit/quantization/bit_allocation/ilp.py
@@ -0,0 +1,126 @@
import pulp
from sparsebit.quantization.modules import QConv2d, QLinear


def feature_ilp_search(qmodel, perturbations, bops_limitation):
    print("Starting feature ILP search!")
    layer_names = list(perturbations.keys())
    layer_modules = [getattr(qmodel.model, name) for name in layer_names]
    bit_choices = qmodel.cfg.A.OBSERVER.BIT_CHOICES

    # Define problem
    problem = pulp.LpProblem("feature bit allocation", pulp.LpMinimize)
    var = pulp.LpVariable.matrix(
        "feature",
        (range(len(layer_names)), range(len(bit_choices))),
        0,
        1,
        pulp.LpInteger,
    )
    target_values = [
        perturbations[layer_names[i]][bit_choices[j]] * var[i][j]
        for i in range(len(layer_names))
        for j in range(len(bit_choices))
    ]
    problem += pulp.lpSum(target_values)

    # Set limitations
    for i in range(
        len(layer_names)
    ):  # only a single bit choice is chosen for each layer
        problem += pulp.lpSum(var[i]) == 1
    # add max BOPS limitation
    total_bops = [
        layer_modules[i].flops * bit_choices[j] * 8 * var[i][j]
        for i in range(len(layer_names))
        for j in range(len(bit_choices))
    ]
    problem += pulp.lpSum(total_bops) <= bops_limitation + 1

    # Calculate results
    problem.solve(pulp.PULP_CBC_CMD(timeLimit=180))
    bit_allocated = {}
    print("Status: " + pulp.LpStatus[problem.status])
    if pulp.LpStatus[problem.status] != "Optimal":
        raise ValueError("Integer Linear Programming no solution!")
    for v in problem.variables():
        if "__" in v.name:
            continue
        _, layer_idx, bit_idx = v.name.split("_")
        layer_idx = int(layer_idx)
        bit_idx = int(bit_idx)
        if v.varValue > 0.5:
            bit_allocated[layer_names[layer_idx]] = bit_choices[bit_idx]
    print(len(problem.variables()))
    print(bit_allocated)

    return bit_allocated


def weight_ilp_search(qmodel, perturbations, bops_limitation, memory_limitation):
    print("Starting weight ILP search!")
    layer_names = list(perturbations.keys())
    layer_modules = [getattr(qmodel.model, name) for name in layer_names]
    bit_choices = qmodel.cfg.W.OBSERVER.BIT_CHOICES

    # Define problem
    problem = pulp.LpProblem("weight bit allocation", pulp.LpMinimize)
    var = pulp.LpVariable.matrix(
        "weight",
        (range(len(layer_names)), range(len(bit_choices))),
        0,
        1,
        pulp.LpInteger,
    )
    target_values = [
        perturbations[layer_names[i]][bit_choices[j]] * var[i][j]
        for i in range(len(layer_names))
        for j in range(len(bit_choices))
    ]
    problem += pulp.lpSum(target_values)

    # Set limitations
    for i in range(
        len(layer_names)
    ):  # only a single bit choice is chosen for each layer
        problem += pulp.lpSum(var[i]) == 1
    # add memory limitation
    total_memory = [
        layer_modules[i].weight.numel() * bit_choices[j] / 8 * var[i][j]
        for i in range(len(layer_names))
        for j in range(len(bit_choices))
    ]
    problem += pulp.lpSum(total_memory) <= memory_limitation
    # add max BOPS limitation
    total_bops = [
        layer_modules[i].flops
        * layer_modules[i].input_quantizer.bit
        * bit_choices[j]
        * var[i][j]
        for i in range(len(layer_names))
        for j in range(len(bit_choices))
    ]
    problem += pulp.lpSum(total_bops) <= bops_limitation + 1

    # Calculate results
    problem.solve(pulp.PULP_CBC_CMD(timeLimit=180))
    bit_allocated = {}
    print("Status: " + pulp.LpStatus[problem.status])
    if pulp.LpStatus[problem.status] != "Optimal":
        raise ValueError("Integer Linear Programming no solution!")
    for v in problem.variables():
        if "__" in v.name:
            continue
        _, layer_idx, bit_idx = v.name.split("_")
        layer_idx = int(layer_idx)
        bit_idx = int(bit_idx)
        if v.varValue > 0.5:
            bit_allocated[layer_names[layer_idx]] = bit_choices[bit_idx]
    print(len(problem.variables()))
    print(bit_allocated)

    return bit_allocated
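Note: both searches are instances of the same 0/1 assignment ILP: minimize the summed perturbation subject to exactly one bit choice per layer and a BOPs budget (plus a weight-memory budget in the weight search). A self-contained toy instance of the same PuLP pattern, with made-up perturbations and FLOPs:

    import pulp

    # 2 layers, bit choices 4/8; perturb[layer][bit_idx], activations fixed at 8 bits
    perturb = [[0.9, 0.1], [0.5, 0.05]]
    flops = [100, 200]
    bits = [4, 8]
    budget = 100 * 8 * 8 + 200 * 4 * 8  # BOPs budget = 12800

    prob = pulp.LpProblem("toy_bit_allocation", pulp.LpMinimize)
    x = pulp.LpVariable.matrix("x", (range(2), range(2)), 0, 1, pulp.LpInteger)
    prob += pulp.lpSum(perturb[i][j] * x[i][j] for i in range(2) for j in range(2))
    for i in range(2):  # exactly one bit per layer
        prob += pulp.lpSum(x[i]) == 1
    prob += pulp.lpSum(
        flops[i] * bits[j] * 8 * x[i][j] for i in range(2) for j in range(2)
    ) <= budget
    prob.solve(pulp.PULP_CBC_CMD(msg=False))
    print({i: bits[j] for i in range(2) for j in range(2) if x[i][j].value() > 0.5})

With this budget the solver must trade off: giving both layers 8 bits would cost 19200 BOPs, so it keeps 8 bits for the cheap layer and drops the expensive one to 4, printing {0: 8, 1: 4}.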
6 changes: 6 additions & 0 deletions sparsebit/quantization/bit_allocation/metric/__init__.py
@@ -0,0 +1,6 @@
from .greedy import get_perturbations as get_perturbations_by_greedy
from .hawq import *

metric_factory = {
    "greedy": get_perturbations_by_greedy,
}
71 changes: 71 additions & 0 deletions sparsebit/quantization/bit_allocation/metric/greedy.py
@@ -0,0 +1,71 @@
from sparsebit.quantization.modules import QConv2d, QLinear


def mse(pred, target, p=2.0):
    return (pred - target).abs().pow(p).mean()


def get_perturbations(qmodel, data):
    qmodel.set_quant(False, False)
    float_output = qmodel(data.cuda())
    weight_perturbation = {}
    feature_perturbation = {}
    for node in qmodel.model.graph.nodes:
        if node.op in ["placeholder", "output"]:
            continue
        module = getattr(qmodel.model, node.target)
        if (
            isinstance(module, (QConv2d, QLinear))
            and getattr(module, "input_quantizer", None)
            and getattr(module, "weight_quantizer", None)
            and not module.input_quantizer.fake_fused
        ):
            print("Layer name:", node.target)
            print("FLOPs:", module.flops)
            print(" Feature:")
            feature_perturbation[node.target] = {}
            for bit in module.input_quantizer.observer.scales.keys():
                module.input_quantizer.set_bit(bit)
                module.input_quantizer.scale = (
                    module.input_quantizer._broadcast_qparams(
                        module.input_quantizer.observer.scales[bit]
                    )
                )
                module.input_quantizer.zero_point = (
                    module.input_quantizer._broadcast_qparams(
                        module.input_quantizer.observer.zero_points[bit]
                    )
                )
                module.set_quant(False, True)
                quant_output = qmodel(data.cuda())
                perturbation = mse(float_output, quant_output)
                module.set_quant(False, False)
                print(
                    " Bit:", str(bit), "Perturbation:", str(perturbation.item())
                )
                feature_perturbation[node.target][bit] = perturbation.item()

            print(" Weight:")
            weight_perturbation[node.target] = {}
            for bit in module.weight_quantizer.observer.scales.keys():
                module.weight_quantizer.set_bit(bit)
                module.weight_quantizer.scale = (
                    module.weight_quantizer._broadcast_qparams(
                        module.weight_quantizer.observer.scales[bit]
                    )
                )
                module.weight_quantizer.zero_point = (
                    module.weight_quantizer._broadcast_qparams(
                        module.weight_quantizer.observer.zero_points[bit]
                    )
                )
                module.set_quant(True, False)
                quant_output = qmodel(data.cuda())
                perturbation = mse(float_output, quant_output)
                module.set_quant(False, False)
                print(
                    " Bit:", str(bit), "Perturbation:", str(perturbation.item())
                )
                weight_perturbation[node.target][bit] = perturbation.item()

    return feature_perturbation, weight_perturbation
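Note: the greedy metric enables exactly one quantizer (feature or weight) at one candidate bit at a time and measures the output drift against the float model with the mse helper. On toy tensors:

    import torch

    # the mse() helper above: mean of |pred - target|^p
    pred = torch.tensor([1.0, 2.0, 3.0])
    target = torch.tensor([1.5, 2.0, 2.0])
    print((pred - target).abs().pow(2.0).mean())  # tensor(0.4167)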
Empty file.
32 changes: 32 additions & 0 deletions sparsebit/quantization/bit_allocation/utils.py
@@ -0,0 +1,32 @@
from sparsebit.quantization.modules import QConv2d, QLinear
from sparsebit.quantization.modules.base import QuantOpr


def calc_flops_and_limitations(model, target_w_bit, target_a_bit):
    bops_limitation = 0
    bops_limitation_for_feature_search = 0
    memory_limitation = 0
    for node in model.graph.nodes:
        if node.op in ["placeholder", "output"]:
            continue
        module = getattr(model, node.target)
        if (
            isinstance(module, (QConv2d, QLinear))
            and getattr(module, "input_quantizer", None)
            and getattr(module, "weight_quantizer", None)
            and not module.input_quantizer.fake_fused
        ):
            module.flops = module.weight.numel()
            if isinstance(module, QConv2d):
                module.flops *= module.output_hw[0] * module.output_hw[1]
            elif (
                isinstance(module, QLinear)
                and module.input_quantizer.observer.qdesc._ch_axis == 2
            ):
                module.flops *= module.seq_len
            bops = module.flops * target_w_bit * target_a_bit
            bops_limitation_for_feature_search += module.flops * 8 * target_a_bit
            bops_limitation += bops
            memory_limitation += module.weight.numel() * target_w_bit / 8  # Byte

    return bops_limitation, bops_limitation_for_feature_search, memory_limitation
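Note: here flops counts weight elements times output positions (spatial positions for convs, sequence length for token-wise linears), BOPs scale flops by both operand bit widths, and weight memory is numel times bits over 8. A worked example with invented shapes:

    # hypothetical 3x3 conv, 64 -> 128 channels, 28x28 output, targets W4/A8
    weight_numel = 128 * 64 * 3 * 3          # 73,728 weight elements
    flops = weight_numel * 28 * 28           # 57,802,752
    bops = flops * 4 * 8                     # 1,849,688,064 target-bit BOPs
    feature_search_bops = flops * 8 * 8      # 3,699,376,128 (weights pinned at 8)
    memory_bytes = weight_numel * 4 / 8      # 36,864 bytes of weight storage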
20 changes: 9 additions & 11 deletions sparsebit/quantization/observers/aciq.py
@@ -62,27 +62,26 @@ def __init__(self, config, qdesc):
         }
         self.gaus_const = (0.5 * 0.35) * (1 + (math.pi * math.log(4)) ** 0.5)
 
-    def calc_laplace_minmax(self):
+    def calc_laplace_minmax(self, bit):
         if self.is_perchannel:
             data = self.data_cache.get_data_for_calibration(Granularity.CHANNELWISE)
             b = torch.mean(torch.abs(data - data.mean(1).unsqueeze(1)), dim=1)
         else:
             data = self.data_cache.get_data_for_calibration(Granularity.LAYERWISE)
             b = torch.mean(torch.abs(data - data.mean()))
-        self.data_cache.reset()
         is_half_range = data.min() >= 0
         if (
             self.qdesc.scheme in [torch.per_channel_affine, torch.per_tensor_affine]
             and is_half_range
         ):
-            max_val = self.alpha_laplace_positive[self.qdesc.bit] * b
+            max_val = self.alpha_laplace_positive[bit] * b
             min_val = torch.zeros(max_val.shape)
         else:
-            max_val = self.alpha_laplace[self.qdesc.bit] * b
+            max_val = self.alpha_laplace[bit] * b
             min_val = -max_val
         return min_val, max_val
 
-    def calc_gaus_minmax(self):
+    def calc_gaus_minmax(self, bit):
         if self.qdesc.target == QuantTarget.FEATURE:
             batch_size = self.data_cache.get_batch_size()
         if self.is_perchannel:
@@ -94,7 +93,6 @@ def calc_gaus_minmax(self):
             max_val = data.max()
             min_val = data.min()
         self.data_cache.get_batch_size
-        self.data_cache.reset()
         is_half_range = data.min() >= 0
         num_elements = data.numel()
         if self.qdesc.target == QuantTarget.FEATURE:
@@ -106,18 +104,18 @@ def calc_gaus_minmax(self):
             self.qdesc.scheme in [torch.per_channel_affine, torch.per_tensor_affine]
             and is_half_range
         ):
-            max_val = self.alpha_gaus_positive[self.qdesc.bit] * std
+            max_val = self.alpha_gaus_positive[bit] * std
             min_val = torch.zeros(max_val.shape)
         else:
-            max_val = self.alpha_gaus[self.qdesc.bit] * std
+            max_val = self.alpha_gaus[bit] * std
             min_val = -max_val
         return min_val, max_val
 
-    def calc_minmax(self):
+    def calc_minmax(self, bit):
         if self.distribution == "laplace":
-            min_val, max_val = self.calc_laplace_minmax()
+            min_val, max_val = self.calc_laplace_minmax(bit)
         else:
-            min_val, max_val = self.calc_gaus_minmax()
+            min_val, max_val = self.calc_gaus_minmax(bit)
         self.min_val = min_val.to(self.device)
         self.max_val = max_val.to(self.device)
 
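Note: threading bit through calc_minmax lets one calibration pass produce clipping ranges for every candidate bit width instead of the single fixed qdesc.bit, which is what the per-bit observer.scales tables used elsewhere in this commit rely on; dropping data_cache.reset() presumably keeps the cached calibration data alive across those repeated calls, with the reset handled by the caller. A sketch of the ACIQ-style Laplace rule being parameterized (alpha constants below are illustrative stand-ins, not this file's table):

    import torch

    alpha_laplace = {2: 2.83, 4: 5.03, 8: 9.89}  # illustrative per-bit constants

    def laplace_clip_range(data: torch.Tensor, bit: int):
        # fit the Laplace scale b, then clip at a bit-dependent multiple of it
        b = (data - data.mean()).abs().mean()
        max_val = alpha_laplace[bit] * b
        return -max_val, max_val

    x = torch.randn(1024)
    print(laplace_clip_range(x, 4))  # lower alpha than 8-bit, so a tighter range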
(Diff truncated: the remaining changed files are not shown here.)
