-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
19 changed files
with
368 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from .utils import * | ||
from .ilp import weight_ilp_search, feature_ilp_search | ||
from .metric import metric_factory | ||
from .bit_allocation import * | ||
|
||
|
||
def bit_allocation(qmodel, data):
    """Run the full mixed-precision bit-allocation pipeline on *qmodel*.

    Reads the average weight/feature bit targets from the model config,
    derives the BOPs/memory budgets, measures per-layer quantization
    perturbations, then solves two ILPs (features first, then weights)
    and applies the chosen bit-widths to the model in place.

    Args:
        qmodel: quantized model wrapper exposing ``.cfg`` and ``.model``.
        data: calibration batch forwarded through the model by the metric.
    """
    sched_cfg = qmodel.cfg.SCHEDULE.BIT_ALLOCATION
    w_target = sched_cfg.AVG_WEIGHT_BIT_TARGET
    a_target = sched_cfg.AVG_FEATURE_BIT_TARGET
    bops_cap, feature_bops_cap, memory_cap = calc_flops_and_limitations(
        qmodel.model, w_target, a_target
    )
    # Greedy metric returns per-layer perturbation tables for both domains.
    feat_pert, wt_pert = metric_factory["greedy"](qmodel, data)
    # Features are allocated first; the weight search then uses the
    # feature bits already written into the model.
    feat_bits = feature_ilp_search(qmodel, feat_pert, feature_bops_cap)
    feature_bit_allocation(qmodel, feat_bits)
    wt_bits = weight_ilp_search(qmodel, wt_pert, bops_cap, memory_cap)
    weight_bit_allocation(qmodel, wt_bits)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from sparsebit.quantization.modules import QuantOpr | ||
|
||
|
||
def feature_bit_allocation(qmodel, bit_allocated):
    """Write the searched input-activation bit-widths into the model.

    For every graph node named in *bit_allocated*, sets the module's
    input quantizer to the chosen bit and refreshes its scale/zero-point
    from the observer's per-bit tables.

    Args:
        qmodel: quantized model wrapper exposing ``.model`` (an fx module).
        bit_allocated: mapping of layer name -> chosen activation bit.
    """
    for node in qmodel.model.graph.nodes:
        if node.target not in bit_allocated:
            continue
        bit = bit_allocated[node.target]
        quantizer = getattr(qmodel.model, node.target).input_quantizer
        quantizer.set_bit(bit)
        # Observer keeps qparams per candidate bit; broadcast to the
        # quantizer's expected parameter shape.
        quantizer.scale = quantizer._broadcast_qparams(
            quantizer.observer.scales[bit]
        )
        quantizer.zero_point = quantizer._broadcast_qparams(
            quantizer.observer.zero_points[bit]
        )
|
||
|
||
def weight_bit_allocation(qmodel, bit_allocated):
    """Write the searched weight bit-widths into the model.

    Mirrors ``feature_bit_allocation`` but updates each layer's weight
    quantizer instead of its input quantizer.

    Args:
        qmodel: quantized model wrapper exposing ``.model`` (an fx module).
        bit_allocated: mapping of layer name -> chosen weight bit.
    """
    for node in qmodel.model.graph.nodes:
        if node.target not in bit_allocated:
            continue
        bit = bit_allocated[node.target]
        quantizer = getattr(qmodel.model, node.target).weight_quantizer
        quantizer.set_bit(bit)
        # Pull the observer's per-bit qparams for the selected width.
        quantizer.scale = quantizer._broadcast_qparams(
            quantizer.observer.scales[bit]
        )
        quantizer.zero_point = quantizer._broadcast_qparams(
            quantizer.observer.zero_points[bit]
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import pulp | ||
from sparsebit.quantization.modules import QConv2d, QLinear | ||
|
||
|
||
def feature_ilp_search(qmodel, perturbations, bops_limitation):
    """Choose one activation bit-width per layer by integer linear programming.

    Minimizes the summed measured perturbation subject to a total-BOPs
    budget (weights assumed at 8 bits during this first search stage).

    Args:
        qmodel: quantized model wrapper; layer modules must carry ``.flops``.
        perturbations: {layer_name: {bit: perturbation}} from the metric.
        bops_limitation: BOPs budget for the feature-only search.

    Returns:
        dict mapping layer name -> selected activation bit.

    Raises:
        ValueError: if the solver does not reach an optimal solution.
    """
    print("Starting feature ILP search!")
    names = list(perturbations.keys())
    modules = [getattr(qmodel.model, n) for n in names]
    bits = qmodel.cfg.A.OBSERVER.BIT_CHOICES

    # One 0/1 indicator per (layer, bit) pair; objective = total perturbation.
    problem = pulp.LpProblem("feature bit allocation", pulp.LpMinimize)
    choice = pulp.LpVariable.matrix(
        "feature",
        (range(len(names)), range(len(bits))),
        0,
        1,
        pulp.LpInteger,
    )
    problem += pulp.lpSum(
        perturbations[name][b] * choice[i][j]
        for i, name in enumerate(names)
        for j, b in enumerate(bits)
    )

    # Exactly one bit-width must be picked for each layer.
    for row in choice:
        problem += pulp.lpSum(row) == 1
    # BOPs budget: flops * activation_bit * 8 (fixed 8-bit weights here).
    # The +1 gives the solver a little slack at the boundary.
    problem += (
        pulp.lpSum(
            modules[i].flops * b * 8 * choice[i][j]
            for i in range(len(names))
            for j, b in enumerate(bits)
        )
        <= bops_limitation + 1
    )

    problem.solve(pulp.PULP_CBC_CMD(timeLimit=180))
    print("Status: " + pulp.LpStatus[problem.status])
    if pulp.LpStatus[problem.status] != "Optimal":
        raise ValueError("Integer Linear Programming no solution!")
    # Variable names look like "feature_<layer>_<bit>"; dummy vars contain "__".
    bit_allocated = {}
    for v in problem.variables():
        if "__" in v.name:
            continue
        _, layer_idx, bit_idx = v.name.split("_")
        if v.varValue > 0.5:
            bit_allocated[names[int(layer_idx)]] = bits[int(bit_idx)]
    print(len(problem.variables()))
    print(bit_allocated)

    return bit_allocated
|
||
|
||
def weight_ilp_search(qmodel, perturbations, bops_limitation, memory_limitation):
    """Choose one weight bit-width per layer by integer linear programming.

    Minimizes the summed measured perturbation subject to both a total
    BOPs budget (using the activation bits already assigned to each
    layer's input quantizer) and a weight-memory budget in bytes.

    Args:
        qmodel: quantized model wrapper; layer modules must carry ``.flops``.
        perturbations: {layer_name: {bit: perturbation}} from the metric.
        bops_limitation: total BOPs budget.
        memory_limitation: weight storage budget in bytes.

    Returns:
        dict mapping layer name -> selected weight bit.

    Raises:
        ValueError: if the solver does not reach an optimal solution.
    """
    print("Starting weight ILP search!")
    names = list(perturbations.keys())
    modules = [getattr(qmodel.model, n) for n in names]
    bits = qmodel.cfg.W.OBSERVER.BIT_CHOICES

    # One 0/1 indicator per (layer, bit) pair; objective = total perturbation.
    problem = pulp.LpProblem("weight bit allocation", pulp.LpMinimize)
    choice = pulp.LpVariable.matrix(
        "weight",
        (range(len(names)), range(len(bits))),
        0,
        1,
        pulp.LpInteger,
    )
    problem += pulp.lpSum(
        perturbations[name][b] * choice[i][j]
        for i, name in enumerate(names)
        for j, b in enumerate(bits)
    )

    # Exactly one bit-width must be picked for each layer.
    for row in choice:
        problem += pulp.lpSum(row) == 1
    # Weight-memory budget: numel * bit / 8 bytes per layer.
    problem += (
        pulp.lpSum(
            modules[i].weight.numel() * b / 8 * choice[i][j]
            for i in range(len(names))
            for j, b in enumerate(bits)
        )
        <= memory_limitation
    )
    # BOPs budget using each layer's already-allocated activation bit.
    # The +1 gives the solver a little slack at the boundary.
    problem += (
        pulp.lpSum(
            modules[i].flops * modules[i].input_quantizer.bit * b * choice[i][j]
            for i in range(len(names))
            for j, b in enumerate(bits)
        )
        <= bops_limitation + 1
    )

    problem.solve(pulp.PULP_CBC_CMD(timeLimit=180))
    print("Status: " + pulp.LpStatus[problem.status])
    if pulp.LpStatus[problem.status] != "Optimal":
        raise ValueError("Integer Linear Programming no solution!")
    # Variable names look like "weight_<layer>_<bit>"; dummy vars contain "__".
    bit_allocated = {}
    for v in problem.variables():
        if "__" in v.name:
            continue
        _, layer_idx, bit_idx = v.name.split("_")
        if v.varValue > 0.5:
            bit_allocated[names[int(layer_idx)]] = bits[int(bit_idx)]
    print(len(problem.variables()))
    print(bit_allocated)

    return bit_allocated
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from .greedy import get_perturbations as get_perturbations_by_greedy | ||
from .hawq import * | ||
|
||
# Registry mapping a sensitivity-metric name to the function that measures
# per-layer quantization perturbations. "greedy" is currently the only entry;
# new metrics register here by name.
metric_factory = {
    "greedy": get_perturbations_by_greedy,
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from sparsebit.quantization.modules import QConv2d, QLinear | ||
|
||
|
||
def mse(pred, target, p=2.0):
    """Mean element-wise error ``|pred - target| ** p``.

    With the default ``p=2.0`` this is the mean squared error; other
    exponents give a generalized Lp-style distortion measure.
    """
    diff = pred - target
    return diff.abs().pow(p).mean()
|
||
|
||
def get_perturbations(qmodel, data):
    """Measure per-layer, per-bit output perturbation against the float model.

    Runs the model once with quantization fully disabled to get a float
    reference output, then, for each quantizable conv/linear layer, enables
    quantization on that single layer — one candidate bit at a time, first
    for its input (feature) quantizer, then for its weight quantizer — and
    records the MSE between the quantized and float outputs.

    Args:
        qmodel: quantized model wrapper; callable on a batch and exposing
            ``.model`` (an fx-traced module) and ``.set_quant``.
        data: calibration batch; moved to GPU via ``.cuda()``, so a CUDA
            device is required.

    Returns:
        (feature_perturbation, weight_perturbation): two dicts of
        {layer_name: {bit: mse_value}}.
    """
    # Disable all quantization to capture the float reference output.
    # NOTE(review): set_quant(w, a) appears to take (weight, activation)
    # flags, based on the per-domain toggles below — confirm against the
    # QuantOpr API.
    qmodel.set_quant(False, False)
    float_output = qmodel(data.cuda())
    weight_perturbation = {}
    feature_perturbation = {}
    for node in qmodel.model.graph.nodes:
        if node.op in ["placeholder", "output"]:
            continue
        module = getattr(qmodel.model, node.target)
        # Only measure conv/linear layers that have both quantizers and are
        # not fake-fused (i.e. their input quantizer is actually active).
        if (
            isinstance(module, (QConv2d, QLinear))
            and getattr(module, "input_quantizer", None)
            and getattr(module, "weight_quantizer", None)
            and not module.input_quantizer.fake_fused
        ):
            print("Layer name:", node.target)
            print("FLOPs:", module.flops)
            print(" Feature:")
            feature_perturbation[node.target] = {}
            # The observer holds qparams for every candidate bit; sweep them.
            for bit in module.input_quantizer.observer.scales.keys():
                module.input_quantizer.set_bit(bit)
                module.input_quantizer.scale = (
                    module.input_quantizer._broadcast_qparams(
                        module.input_quantizer.observer.scales[bit]
                    )
                )
                module.input_quantizer.zero_point = (
                    module.input_quantizer._broadcast_qparams(
                        module.input_quantizer.observer.zero_points[bit]
                    )
                )
                # Quantize only this layer's activations for the probe run,
                # then switch quantization back off before the next probe.
                module.set_quant(False, True)
                quant_output = qmodel(data.cuda())
                perturbation = mse(float_output, quant_output)
                module.set_quant(False, False)
                print(
                    " Bit:", str(bit), "Perturbation:", str(perturbation.item())
                )
                feature_perturbation[node.target][bit] = perturbation.item()

            print(" Weight:")
            weight_perturbation[node.target] = {}
            for bit in module.weight_quantizer.observer.scales.keys():
                module.weight_quantizer.set_bit(bit)
                module.weight_quantizer.scale = (
                    module.weight_quantizer._broadcast_qparams(
                        module.weight_quantizer.observer.scales[bit]
                    )
                )
                module.weight_quantizer.zero_point = (
                    module.weight_quantizer._broadcast_qparams(
                        module.weight_quantizer.observer.zero_points[bit]
                    )
                )
                # Quantize only this layer's weights for the probe run.
                module.set_quant(True, False)
                quant_output = qmodel(data.cuda())
                perturbation = mse(float_output, quant_output)
                module.set_quant(False, False)
                print(
                    " Bit:", str(bit), "Perturbation:", str(perturbation.item())
                )
                weight_perturbation[node.target][bit] = perturbation.item()

    return feature_perturbation, weight_perturbation
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from sparsebit.quantization.modules import QConv2d, QLinear | ||
from sparsebit.quantization.modules.base import QuantOpr | ||
|
||
|
||
def calc_flops_and_limitations(model, target_w_bit, target_a_bit):
    """Annotate quantizable layers with ``.flops`` and compute global budgets.

    Walks the fx graph, computes a FLOPs estimate for every quantizable
    conv/linear layer (cached on the module as ``module.flops``), and
    accumulates three budgets derived from the target average bit-widths.

    Args:
        model: fx-traced model whose nodes name its submodules.
        target_w_bit: target average weight bit-width.
        target_a_bit: target average activation bit-width.

    Returns:
        (bops_limitation, bops_limitation_for_feature_search,
        memory_limitation) — BOPs budget, BOPs budget with weights pinned
        to 8 bits (for the feature-only search), and weight memory in bytes.
    """
    bops_total = 0
    feature_search_bops_total = 0
    memory_total = 0
    for node in model.graph.nodes:
        if node.op in ("placeholder", "output"):
            continue
        module = getattr(model, node.target)
        quantizable = (
            isinstance(module, (QConv2d, QLinear))
            and getattr(module, "input_quantizer", None)
            and getattr(module, "weight_quantizer", None)
            and not module.input_quantizer.fake_fused
        )
        if not quantizable:
            continue
        flops = module.weight.numel()
        if isinstance(module, QConv2d):
            # Conv cost scales with the spatial size of its output map.
            flops *= module.output_hw[0] * module.output_hw[1]
        elif module.input_quantizer.observer.qdesc._ch_axis == 2:
            # Linear applied per token (channel axis 2) repeats per position.
            flops *= module.seq_len
        module.flops = flops  # cached for the later ILP searches
        bops_total += flops * target_w_bit * target_a_bit
        # Feature-only search assumes 8-bit weights.
        feature_search_bops_total += flops * 8 * target_a_bit
        memory_total += module.weight.numel() * target_w_bit / 8  # Byte
    return bops_total, feature_search_bops_total, memory_total
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.