-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathreplacement_scheduler.py
107 lines (90 loc) · 4.19 KB
/
replacement_scheduler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# -*- coding: utf8 -*-
"""
======================================
Project Name: ner-pytorch
File Name: replacement_scheduler
Author: czh
Create Date: 2020/9/14
--------------------------------------
Change Activity:
======================================
"""
import torch
import torch.nn as nn
from torch.distributions.bernoulli import Bernoulli
from transformers.modeling_bert import BertLayer
class BertEncoder(nn.Module):
    """BERT encoder with stochastic layer replacement (BERT-of-Theseus style).

    Holds two stacks of ``BertLayer``s: the full predecessor stack of
    ``config.num_hidden_layers`` layers and a compressed successor stack of
    ``scc_n_layer`` layers.  During training, each successor layer replaces
    its block of ``compress_ratio`` consecutive predecessor layers with
    probability given by the Bernoulli replacing rate; at inference time only
    the compressed successor stack is executed.
    """

    def __init__(self, config, scc_n_layer=6):
        """
        :param config: BERT config; ``num_hidden_layers`` must be an exact
            multiple of ``scc_n_layer`` so each successor layer maps onto a
            contiguous block of predecessor layers.
        :param scc_n_layer: number of layers in the compressed stack.
        :raises ValueError: if the layer counts are not evenly divisible.
        """
        super(BertEncoder, self).__init__()
        self.prd_n_layer = config.num_hidden_layers
        self.scc_n_layer = scc_n_layer
        # Use a real exception, not assert: assert is stripped under -O.
        if self.prd_n_layer % self.scc_n_layer != 0:
            raise ValueError(
                'num_hidden_layers ({}) must be divisible by scc_n_layer ({})'
                .format(self.prd_n_layer, self.scc_n_layer))
        self.compress_ratio = self.prd_n_layer // self.scc_n_layer
        self.bernoulli = None  # set via set_replacing_rate() before training
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(self.prd_n_layer)])
        self.scc_layer = nn.ModuleList([BertLayer(config) for _ in range(self.scc_n_layer)])

    def set_replacing_rate(self, replacing_rate):
        """Set the Bernoulli probability of replacing a predecessor block.

        :param replacing_rate: probability in the half-open range (0, 1].
        :raises ValueError: if the rate is outside (0, 1].
        """
        if not 0 < replacing_rate <= 1:
            raise ValueError('Replace rate must be in the range (0, 1]!')
        self.bernoulli = Bernoulli(torch.tensor([replacing_rate]))

    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None,
                encoder_attention_mask=None):
        """Run the (possibly mixed) layer stack.

        Returns ``(last_hidden_state, [all_hidden_states], [all_attentions])``
        with the optional tuples present only when enabled on the config.
        """
        all_hidden_states = ()
        all_attentions = ()
        if self.training:
            if self.bernoulli is None:
                # Fail loudly instead of the opaque AttributeError on .sample()
                raise RuntimeError('set_replacing_rate() must be called before training')
            inference_layers = []
            for i in range(self.scc_n_layer):
                if self.bernoulli.sample() == 1:  # REPLACE with the successor layer
                    inference_layers.append(self.scc_layer[i])
                else:  # KEEP the original block of predecessor layers
                    for offset in range(self.compress_ratio):
                        inference_layers.append(self.layer[i * self.compress_ratio + offset])
        else:  # inference with compressed model
            inference_layers = self.scc_layer
        for i, layer_module in enumerate(inference_layers):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            # head_mask defaults to None; indexing it unguarded crashed before.
            layer_head_mask = head_mask[i] if head_mask is not None else None
            layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask,
                                         encoder_hidden_states, encoder_attention_mask)
            hidden_states = layer_outputs[0]
            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)
class ConstantReplacementScheduler:
    """Hold the replacing rate constant, optionally switching to full
    replacement (rate 1.0) once ``replacing_steps`` steps have elapsed."""

    def __init__(self, bert_encoder: BertEncoder, replacing_rate, replacing_steps=None):
        """
        :param bert_encoder: encoder whose replacing rate this schedules.
        :param replacing_rate: constant Bernoulli replacing probability.
        :param replacing_steps: step count after which the rate jumps to 1.0;
            ``None`` keeps the rate constant forever.
        """
        self.bert_encoder = bert_encoder
        self.replacing_rate = replacing_rate
        self.replacing_steps = replacing_steps
        self.step_counter = 0
        self.bert_encoder.set_replacing_rate(replacing_rate)

    def step(self):
        """Advance one step and return the current replacing rate."""
        self.step_counter += 1
        # Guard clause: no scheduled switch, or already fully replaced.
        if self.replacing_steps is None or self.replacing_rate == 1.0:
            return self.replacing_rate
        if self.step_counter >= self.replacing_steps:
            self.bert_encoder.set_replacing_rate(1.0)
            self.replacing_rate = 1.0
        return self.replacing_rate
class LinearReplacementScheduler:
    """Grow the replacing rate linearly with the step count:
    ``rate = k * step + base_replacing_rate``, capped at 1.0."""

    def __init__(self, bert_encoder: BertEncoder, base_replacing_rate, k):
        """
        :param bert_encoder: encoder whose replacing rate this schedules.
        :param base_replacing_rate: rate at step 0 (the intercept).
        :param k: per-step increase of the rate (the slope).
        """
        self.bert_encoder = bert_encoder
        self.base_replacing_rate = base_replacing_rate
        self.step_counter = 0
        self.k = k
        self.bert_encoder.set_replacing_rate(base_replacing_rate)

    def step(self):
        """Advance one step, push the new rate to the encoder, return it."""
        self.step_counter += 1
        rate = self.k * self.step_counter + self.base_replacing_rate
        if rate > 1.0:  # saturate at full replacement
            rate = 1.0
        self.bert_encoder.set_replacing_rate(rate)
        return rate