-
Notifications
You must be signed in to change notification settings - Fork 22
/
stackmapcompress.py
375 lines (326 loc) · 11.4 KB
/
stackmapcompress.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
# -*- indent-tabs-mode: nil -*-
# Parse output of "go build -gcflags=all=-S -a cmd/go >& /tmp/go.s" and
# compress register liveness maps in various ways.
import re
import sys
import collections
# Select which liveness maps to analyze: flip the condition to switch
# between register maps and (locals-only) stack maps.
if True:
    # Register maps.
    FUNCDATA, PCDATA = "3", "2"
else:
    # Stack maps: locals only (not args).
    FUNCDATA, PCDATA = "1", "0"
class Stackmap:
    """A table of liveness bitmaps in the style of a Go runtime.stackmap.

    bitmaps holds the table entries as Python ints (presumably
    non-negative); nbit is the bit width used when encoding each entry.
    """

    def __init__(self, dec=None):
        """Create an empty table, or decode one from the Decoder dec.

        The decoded form is an int32 entry count, an int32 bit width
        (nbit), then that many nbit-bit bitmaps.
        """
        if dec is None:
            self.n = self.nbit = 0
            self.bitmaps = []
        else:
            # Decode Go encoding of a runtime.stackmap.
            count = dec.int32()
            self.nbit = dec.int32()
            self.bitmaps = [dec.bitmap(self.nbit) for _ in range(count)]

    def clone(self):
        """Return an independent copy made by an encode/decode round trip."""
        enc = Encoder()
        self.encode(enc)
        return Stackmap(Decoder(enc.b))

    def add(self, bitmap):
        """Return the index of bitmap in the table, appending it if absent.

        Also widens nbit so it covers bitmap's highest set bit.
        """
        self.nbit = max(self.nbit, bitmap.bit_length())
        try:
            return self.bitmaps.index(bitmap)
        except ValueError:
            self.bitmaps.append(bitmap)
            return len(self.bitmaps) - 1

    def sort(self):
        """Sort the bitmaps by value in place and return the index remapping.

        Returns a list remap such that the bitmap previously at index i is
        now at index remap[i]. Callers renumber existing references with
        remap[oldIndex].
        """
        # order[newIdx] == oldIdx; a stable sort breaks ties by old index.
        order = sorted(range(len(self.bitmaps)), key=self.bitmaps.__getitem__)
        # Invert the permutation so it can be indexed by the *old* index,
        # which is how callers use it.
        remap = [0] * len(order)
        for newIdx, oldIdx in enumerate(order):
            remap[oldIdx] = newIdx
        self.bitmaps = [self.bitmaps[oldIdx] for oldIdx in order]
        return remap

    def encode(self, enc, compact=False):
        """Write the table to the Encoder enc.

        Both forms start with the int32 entry count. The default form then
        writes nbit as an int32 followed by each bitmap on its own. The
        compact form writes nbit as a single byte followed by all bitmaps
        packed back to back into one (count * nbit)-bit field.
        """
        enc.int32(len(self.bitmaps))
        if compact:
            enc.uint8(self.nbit)
            combined = 0
            for i, b in enumerate(self.bitmaps):
                combined |= b << (i * self.nbit)
            enc.bitmap(combined, len(self.bitmaps) * self.nbit)
        else:
            enc.int32(self.nbit)
            for b in self.bitmaps:
                enc.bitmap(b, self.nbit)
class PCData:
    """An ordered table of (pc, value) pairs."""

    def __init__(self):
        # [(pc, value)], presumably appended in increasing pc order.
        self.pcdata = []

    def _deltas(self):
        """Yield (pc delta from the previous entry, value) for each entry."""
        prevPC = 0
        for pc, val in self.pcdata:
            yield pc - prevPC, val
            prevPC = pc

    def encode(self, enc):
        """Write each entry as a uvarint pc delta and an svarint value delta.

        A single 0 byte terminates the table.
        """
        prevPC, prevVal = 0, 0
        for pc, val in self.pcdata:
            enc.uvarint(pc - prevPC)
            enc.svarint(val - prevVal)
            prevPC, prevVal = pc, val
        enc.uint8(0)

    def huffSize(self, pcHuff, valHuff):
        """Return the byte size (rounded up) of Huffman-coding this table.

        pcHuff and valHuff map pc deltas and values, respectively, to
        (codeword, nbits) pairs such as those produced by huffman().
        """
        nbits = sum(pcHuff[delta][1] + valHuff[val][1]
                    for delta, val in self._deltas())
        return (nbits + 7) // 8

    def grSize(self, pcHuff, n):
        """Return the byte size (rounded up) using Huffman-coded pc deltas
        and Golomb-Rice-coded values.

        Each value is coded as val + 1 with Golomb-Rice parameter n; the
        module-level grSize() gives the per-value bit count.
        """
        nbits = sum(pcHuff[delta][1] + grSize(val + 1, n)
                    for delta, val in self._deltas())
        return (nbits + 7) // 8
def grSize(val, n):
    """Return the bit length of val's Golomb-Rice code with parameter n.

    The quotient val >> n is written in unary (that many 1 bits plus a
    terminating bit), followed by the n low-order remainder bits.
    """
    unaryBits = (val >> n) + 1
    remainderBits = n
    return unaryBits + remainderBits
class Decoder:
    """Sequential little-endian reader that consumes a bytes-like buffer."""

    def __init__(self, b):
        # A memoryview lets each read advance without copying the buffer.
        self.b = memoryview(b)

    def int32(self):
        """Consume 4 bytes and return them as an unsigned little-endian int."""
        head, self.b = self.b, self.b[4:]
        return sum(head[k] << (8 * k) for k in range(4))

    def bitmap(self, nbits):
        """Consume ceil(nbits / 8) bytes and return them as a little-endian int."""
        nbytes = (nbits + 7) // 8
        value = sum(self.b[k] << (8 * k) for k in range(nbytes))
        self.b = self.b[nbytes:]
        return value
class Encoder:
    """Append-only little-endian byte writer; output accumulates in self.b."""

    def __init__(self):
        self.b = bytearray()

    def uint8(self, i):
        """Append a single byte."""
        self.b.append(i)

    def int32(self, i):
        """Append the low 32 bits of i, least significant byte first."""
        self.b.extend((i >> shift) & 0xFF for shift in (0, 8, 16, 24))

    def bitmap(self, bits, nbits):
        """Append bits as ceil(nbits / 8) little-endian bytes."""
        self.b.extend((bits >> (8 * k)) & 0xFF for k in range((nbits + 7) // 8))

    def uvarint(self, v):
        """Append v as an unsigned LEB128 varint (low 7-bit groups first)."""
        if v < 0:
            raise ValueError("negative unsigned varint", v)
        while True:
            if v <= 0x7f:
                self.b.append(v)
                return
            self.b.append(0x80 | (v & 0x7f))
            v >>= 7

    def svarint(self, v):
        """Append v zigzag-encoded (0, -1, 1, -2, ... -> 0, 1, 2, 3, ...) as a uvarint."""
        self.uvarint(~(v << 1) if v < 0 else v << 1)
def parse(stream):
    """Parse compiler assembly from stream and collect each function's maps.

    Returns the list of parseasm.Func objects found. Each one gains a
    regMaps attribute: a list of (pc, bitmap) pairs, one per
    "PCDATA $<PCDATA>" instruction. The bitmap is looked up in the Stackmap
    decoded from the symbol named by the most recent "FUNCDATA $<FUNCDATA>"
    instruction in that function.
    """
    import parseasm

    objs = parseasm.parse(stream)
    funcdataPrefix = "FUNCDATA\t$" + FUNCDATA + ", "
    pcdataPrefix = "PCDATA\t$" + PCDATA + ", "
    fns = []
    for obj in objs.values():
        if not isinstance(obj, parseasm.Func):
            continue
        fns.append(obj)
        obj.regMaps = []  # [(pc, register bitmap)]
        current = None
        for inst in obj.insts:
            asm = inst.asm
            if asm.startswith(funcdataPrefix):
                # Second operand names a data symbol; [:-4] presumably strips
                # a "(SB)" suffix.
                symName = asm.split(" ")[1][:-4]
                current = Stackmap(Decoder(objs[symName].data))
            elif asm.startswith(pcdataPrefix):
                # Second operand is "$<index>" into the current Stackmap.
                # NOTE(review): a negative index (if the compiler emits one)
                # would silently select from the end of the bitmap list.
                index = int(asm.split(" ")[1][1:])
                obj.regMaps.append((inst.pc, current.bitmaps[index]))
    return fns
def genStackMaps(fns, padToByte=True, dedup=True, sortBitmaps=False):
    """Build a pcdata table and a register-map Stackmap for each function.

    Sets fn.pcdataRegs (a PCData of (pc, bitmap index) pairs) and
    fn.funcdataRegMap (a Stackmap) on every fn in fns.

    Args:
        fns: functions whose regMaps is a list of (pc, bitmap) pairs.
        padToByte: when deduplicating, compare the default (True) or the
            compact (False) Stackmap encoding.
        dedup: make functions whose encoded maps are identical share a
            single Stackmap object.
        sortBitmaps: sort each function's bitmaps by value and renumber its
            pcdata indexes to match.

    Returns:
        A view of the distinct Stackmaps when dedup is true, otherwise of
        every function's Stackmap.
    """
    regMapSet = {}
    for fn in fns:
        # Create pcdata and register map for fn.
        fn.pcdataRegs = PCData()
        fn.funcdataRegMap = Stackmap()
        for pc, bitmap in fn.regMaps:
            fn.pcdataRegs.pcdata.append((pc, fn.funcdataRegMap.add(bitmap)))
        if sortBitmaps:
            # Renumber pcdata indexes through the remapping Stackmap.sort()
            # returns, indexed by each entry's pre-sort index.
            remap = fn.funcdataRegMap.sort()
            fn.pcdataRegs.pcdata = [(pc, remap[idx]) for pc, idx in fn.pcdataRegs.pcdata]
        # Encode and dedup register maps.
        if dedup:
            e = Encoder()
            fn.funcdataRegMap.encode(e, not padToByte)
            key = bytes(e.b)
            # The first Stackmap seen with this encoding is shared by all later ones.
            fn.funcdataRegMap = regMapSet.setdefault(key, fn.funcdataRegMap)
        else:
            regMapSet[fn] = fn.funcdataRegMap
    return regMapSet.values()
def likeStackMap(fns, padToByte=True, dedup=True, sortBitmaps=None, huffmanPcdata=False, grPcdata=False):
    """Measure fns' liveness maps encoded as per-function bitmap tables plus pcdata.

    Each function gets its own bitmap table (a Stackmap) and a PCData table
    mapping pcs to indexes into that table.

    Args:
        fns: functions whose regMaps is a list of (pc, bitmap) pairs.
        padToByte: if false, use Stackmap's compact bitmap encoding.
        dedup: count identical encoded bitmap tables only once.
        sortBitmaps: None; "freq" (number each function's bitmaps by
            descending frequency); or "value" (sort each table by bitmap
            value).
        huffmanPcdata: estimate pcdata as Huffman-coded pc deltas and indexes.
        grPcdata: if neither False nor None, estimate pcdata as Huffman-coded
            pc deltas plus Golomb-Rice-coded indexes with this parameter
            (0 is valid). huffmanPcdata takes precedence.

    Returns:
        dict of byte counts: "gclocals" (bitmap tables), "pcdata" (pcdata
        tables) and "extra" (12 bytes of per-function overhead).
    """
    # grPcdata=0 is a legitimate Golomb-Rice parameter, so compare against
    # the "disabled" values explicitly rather than testing truthiness.
    useGR = grPcdata is not None and grPcdata is not False
    regMapSet = set()
    regMaps = bytearray()
    pcdatas = []
    extra = 0
    for fn in fns:
        # Create pcdata and register map for fn.
        pcdata = PCData()
        regMap = Stackmap()
        if sortBitmaps == "freq":
            # Pre-populate regMap in frequency order so the most frequent
            # bitmaps get the smallest indexes.
            regMapFreq = collections.Counter(bitmap for _, bitmap in fn.regMaps)
            for bitmap, _ in sorted(regMapFreq.items(), key=lambda item: item[1], reverse=True):
                regMap.add(bitmap)
        for pc, bitmap in fn.regMaps:
            pcdata.pcdata.append((pc, regMap.add(bitmap)))
        if sortBitmaps == "value":
            # Renumber pcdata indexes through the remapping Stackmap.sort()
            # returns, indexed by each entry's pre-sort index.
            remap = regMap.sort()
            pcdata.pcdata = [(pc, remap[idx]) for pc, idx in pcdata.pcdata]
        pcdatas.append(pcdata)
        # Encode register map and dedup.
        e = Encoder()
        regMap.encode(e, not padToByte)
        encoded = bytes(e.b)
        if not dedup or encoded not in regMapSet:
            regMapSet.add(encoded)
            regMaps.extend(encoded)
        extra += 8 + 4  # funcdata pointer, pcdata table offset
    # Encode pcdata.
    pcdataEnc = Encoder()
    if huffmanPcdata or useGR:
        pcDeltas, _ = countDeltas(fns)
        pcHuff = huffman(pcDeltas)
        if huffmanPcdata:
            pcdataHist = collections.Counter(
                idx for pcdata in pcdatas for _, idx in pcdata.pcdata)
            pcdataHuff = huffman(pcdataHist)
            size = sum(pcdata.huffSize(pcHuff, pcdataHuff) for pcdata in pcdatas)
        else:
            size = sum(pcdata.grSize(pcHuff, grPcdata) for pcdata in pcdatas)
        # Only the length is reported, so a zero-filled placeholder suffices.
        pcdataEnc.b = bytearray(size)
    else:
        for pcdata in pcdatas:
            pcdata.encode(pcdataEnc)
    return {"gclocals": len(regMaps), "pcdata": len(pcdataEnc.b), "extra": extra}
def filterLiveToDead(fns):
    """Keep only the regMaps entries at which some register becomes newly live.

    This is a lower bound on what the "don't care" optimization could
    achieve. Each fn.regMaps is replaced by a new filtered list of
    (pc, bitmap) pairs. None entries are always kept and reset the
    baseline to an empty bitmap.
    """
    for fn in fns:
        newRegMaps = []
        prevBitmap = 0
        for pc, bitmap in fn.regMaps:
            if bitmap is None:
                newRegMaps.append((pc, None))
                prevBitmap = 0
                continue
            if bitmap & ~prevBitmap != 0:
                # New bits set.
                newRegMaps.append((pc, bitmap))
            # Compare against the actual previous bitmap, not the last emitted one.
            prevBitmap = bitmap
        fn.regMaps = newRegMaps
def total(dct):
    """Store the sum of dct's other values under "total" and return dct.

    dct is modified in place. Any existing "total" entry is excluded from
    the sum, so calling this again does not double count.
    """
    dct["total"] = sum(value for key, value in dct.items() if key != "total")
    return dct
def iterDeltas(regMaps):
    """Yield (pc delta, bitmap delta) for each (pc, bitmap) in regMaps.

    The bitmap delta is the XOR with the previous bitmap. A None bitmap
    yields a None delta and resets the previous bitmap to 0.
    """
    lastPC = 0
    lastBitmap = 0
    for pc, bitmap in regMaps:
        delta = None if bitmap is None else bitmap ^ lastBitmap
        yield pc - lastPC, delta
        lastPC = pc
        lastBitmap = 0 if bitmap is None else bitmap
def countMaps(fns):
    """Return a Counter of how often each bitmap appears in all fns' regMaps."""
    return collections.Counter(bitmap for fn in fns for _, bitmap in fn.regMaps)
def countDeltas(fns):
    """Return Counters (pcDeltas, bitmapDeltas) over iterDeltas() of every fn's regMaps."""
    # This actually spreads out the head of the distribution quite a bit
    # because things are more likely to die in clumps and at the same time
    # as something else becomes live.
    #filterLiveToDead(fns)
    pairs = [pair for fn in fns for pair in iterDeltas(fn.regMaps)]
    pcDeltas = collections.Counter(pcDelta for pcDelta, _ in pairs)
    bitmapDeltas = collections.Counter(bitmapDelta for _, bitmapDelta in pairs)
    return pcDeltas, bitmapDeltas
def huffman(counts, streamAlign=1):
    """Build a Huffman code for the symbols in counts.

    Args:
        counts: mapping from symbol to frequency (e.g. a collections.Counter).
        streamAlign: bits per code digit. The tree has radix
            2**streamAlign, so every codeword length is a multiple of
            streamAlign.

    Returns:
        dict mapping each symbol to (codeword, nbits). It is empty when
        counts is empty. A lone symbol gets the zero-length code (0, 0).
    """
    if not counts:
        # Nothing to code; the loop/tree walk below would index an empty list.
        return {}
    radix = 2 ** streamAlign
    # Entries are (weight, node); a node is a symbol or a list of child
    # nodes. Symbols are dict keys, hence hashable, hence never lists.
    nodes = [(count, sym) for sym, count in counts.items()]
    while len(nodes) > 1:
        nodes.sort(key=lambda node: node[0], reverse=True)
        # Merge the radix lightest nodes (or everything left, if fewer).
        if len(nodes) < radix:
            children, nodes = nodes, []
        else:
            children, nodes = nodes[-radix:], nodes[:-radix]
        nodes.append((sum(weight for weight, _ in children),
                      [child for _, child in children]))
    tree = {}

    def assign(node, codeword, nbits):
        # Walk the tree, appending one streamAlign-bit digit per level.
        if isinstance(node, list):
            for digit, child in enumerate(node):
                assign(child, (codeword << streamAlign) + digit, nbits + streamAlign)
        else:
            tree[node] = (codeword, nbits)

    assign(nodes[0][1], 0, 0)
    return tree
def huffmanCoded(fns, streamAlign=1):
    """Estimate pcdata size if (pc delta, bitmap delta) pairs were Huffman coded.

    The pc deltas and bitmap deltas each get their own code, built over all
    functions. Each function's bit stream is padded to a byte boundary.

    Returns:
        dict with "pcdata" (bytes) and "extra" (4 bytes per function).
    """
    pcDeltas, maskDeltas = countDeltas(fns)
    pcCode = huffman(pcDeltas, streamAlign)
    maskCode = huffman(maskDeltas, streamAlign)
    totalBits = 0
    extra = 0
    for fn in fns:
        fnBits = sum(pcCode[pcDelta][1] + maskCode[maskDelta][1]
                     for pcDelta, maskDelta in iterDeltas(fn.regMaps))
        # Start each function's stream on a byte boundary.
        totalBits = (totalBits + fnBits + 7) & ~7
        extra += 4  # PCDATA
    return {"pcdata": (totalBits + 7) // 8, "extra": extra}
fns = parse(sys.stdin)

if True:
    # Baseline: per-function bitmap tables with varint-coded pcdata.
    print(total(likeStackMap(fns)))
    # Linker dedup of gclocals reduces gclocals by >2X
    #print(total(likeStackMap(fns, dedup=False)))
    #print(total(likeStackMap(fns, sortBitmaps="value")))
    # 'total': 529225, 'pcdata': 292703, 'gclocals': 77558, 'extra': 158964
    for opts in ({"huffmanPcdata": True},
                 {"huffmanPcdata": True, "sortBitmaps": "freq"}):
        print(total(likeStackMap(fns, **opts)))
    # Golomb-Rice-coded pcdata indexes with parameter n = 0..7.
    for n in range(8):
        print(n, total(likeStackMap(fns, grPcdata=n, sortBitmaps="freq")))
    # NOTE: likeStackMap has no compactBitmap parameter; padToByte=False
    # selects the compact bitmap encoding.
    #print(total(likeStackMap(fns, compactBitmap=True)))
    # 'total': 407999, 'pcdata': 302023, 'extra': 105976
    for align in (1, 8):
        print(total(huffmanCoded(fns, streamAlign=align)))
    # Only emitting on newly live reduces pcdata by 42%, gclocals by 10%
    filterLiveToDead(fns)
    print(total(likeStackMap(fns)))

if False:
    # What do the bitmaps look like?
    for bitmap, count in countMaps(fns).items():
        print(count, bin(bitmap))

if False:
    # What do the bitmap changes look like?
    for delta, count in countDeltas(fns)[1].items():
        print(count, bin(delta))

if False:
    # PC delta histogram
    for delta, count in countDeltas(fns)[0].items():
        print(count, delta)