flow_model.py
import torch
import torch.nn as nn
import numpy as np
import sys
sys.path.append("../../")
from general.mutils import get_param_val, create_transformer_mask, create_channel_mask
from layers.flows.flow_model import FlowModel
from layers.flows.activation_normalization import ActNormFlow
from layers.flows.permutation_layers import InvertibleConv
from layers.flows.coupling_layer import CouplingLayer
from layers.flows.mixture_cdf_layer import MixtureCDFCoupling
from layers.categorical_encoding.mutils import create_encoding

class FlowSetModeling(FlowModel):

    def __init__(self, model_params, dataset_class):
        super().__init__(layers=None, name="Set Modeling Flow")
        self.model_params = model_params
        self.dataset_class = dataset_class
        self.set_size = self.model_params["set_size"]
        self.vocab_size = self.dataset_class.get_vocab_size(self.set_size)
        self._create_layers()
        self.print_overview()
    def _create_layers(self):
        # Dimensionality of the continuous latent space the categories are encoded into
        self.latent_dim = self.model_params["categ_encoding"]["num_dimensions"]
        # Shared constructor for the coupling networks: a small Transformer over the set elements
        model_func = lambda c_out: CouplingTransformerNet(c_in=self.latent_dim,
                                                          c_out=c_out,
                                                          num_layers=self.model_params["coupling_hidden_layers"],
                                                          hidden_size=self.model_params["coupling_hidden_size"])
        self.model_params["categ_encoding"]["flow_config"]["model_func"] = model_func
        self.model_params["categ_encoding"]["flow_config"]["block_type"] = "Transformer"
        # Encoding layer maps the discrete categories into the continuous latent space
        self.encoding_layer = create_encoding(self.model_params["categ_encoding"],
                                              dataset_class=self.dataset_class,
                                              vocab_size=self.vocab_size)
        num_flows = self.model_params["coupling_num_flows"]
        if self.latent_dim > 1:
            # Multi-dimensional latents: split along the channel dimension with a fixed mask
            coupling_mask = CouplingLayer.create_channel_mask(self.latent_dim,
                                                              ratio=self.model_params["coupling_mask_ratio"])
            coupling_mask_func = lambda flow_index: coupling_mask
        else:
            # One-dimensional latents: alternate a chessboard mask over the set positions
            coupling_mask = CouplingLayer.create_chess_mask()
            coupling_mask_func = lambda flow_index: coupling_mask if flow_index % 2 == 0 else 1 - coupling_mask
        layers = []
        for flow_index in range(num_flows):
            layers += [
                ActNormFlow(self.latent_dim),
                InvertibleConv(self.latent_dim),
                MixtureCDFCoupling(c_in=self.latent_dim,
                                   mask=coupling_mask_func(flow_index),
                                   model_func=model_func,
                                   block_type="Transformer",
                                   num_mixtures=self.model_params["coupling_num_mixtures"])
            ]
        self.flow_layers = nn.ModuleList([self.encoding_layer] + layers)
    def forward(self, z, ldj=None, reverse=False, length=None, **kwargs):
        if length is not None:
            # Build Transformer and channel padding masks from the per-example set lengths
            kwargs["src_key_padding_mask"] = create_transformer_mask(length)
            kwargs["channel_padding_mask"] = create_channel_mask(length)
        return super().forward(z, ldj=ldj, reverse=reverse, length=length, **kwargs)
    def get_inner_activations(self, z, length=None, return_names=False, **kwargs):
        # Runs the flow in the forward direction and records the output of every layer
        if length is not None:
            kwargs["length"] = length
            kwargs["src_key_padding_mask"] = create_transformer_mask(length)
            kwargs["channel_padding_mask"] = create_channel_mask(length)
        out_per_layer = []
        layer_names = []
        for layer_index, layer in enumerate(self.flow_layers):
            z = self._run_layer(layer, z, reverse=False, **kwargs)[0]
            out_per_layer.append(z.detach())
            layer_names.append(layer.__class__.__name__)
        if not return_names:
            return out_per_layer
        else:
            return out_per_layer, layer_names
    def initialize_data_dependent(self, batch_list):
        # Batch list needs to consist of tuples: (z, kwargs)
        print("Initializing data dependent...")
        with torch.no_grad():
            # Prepare the padding masks once per batch before the layer-wise initialization
            for batch, kwargs in batch_list:
                kwargs["src_key_padding_mask"] = create_transformer_mask(kwargs["length"])
                kwargs["channel_padding_mask"] = create_channel_mask(kwargs["length"])
            for layer_index, layer in enumerate(self.flow_layers):
                batch_list = FlowModel.run_data_init_layer(batch_list, layer)

class CouplingTransformerNet(nn.Module):

    def __init__(self, c_in, c_out, num_layers, hidden_size):
        super().__init__()
        self.input_layer = nn.Sequential(
            nn.Linear(c_in, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size)
        )
        self.transformer_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(hidden_size,
                                       nhead=4,
                                       dim_feedforward=2 * hidden_size,
                                       dropout=0.0,
                                       activation='gelu') for _ in range(num_layers)
        ])
        self.output_layer = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, c_out)
        )

    def forward(self, x, src_key_padding_mask, **kwargs):
        x = x.transpose(0, 1)  # Transformer layers expect [sequence length, batch size, hidden size]
        x = self.input_layer(x)
        for transformer in self.transformer_layers:
            x = transformer(x, src_key_padding_mask=src_key_padding_mask)
        x = self.output_layer(x)
        x = x.transpose(0, 1)  # Back to [batch size, sequence length, c_out]
        return x
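

# Hypothetical smoke test (not part of the original file): a minimal sketch that runs
# CouplingTransformerNet on random data to illustrate the expected tensor shapes.
# The batch size, set size, channel sizes, and the all-False padding mask below are
# illustrative assumptions, not values taken from the repository's configurations.
if __name__ == "__main__":
    batch_size, set_size, c_in, c_out = 2, 5, 4, 8
    net = CouplingTransformerNet(c_in=c_in, c_out=c_out, num_layers=2, hidden_size=64)
    x = torch.randn(batch_size, set_size, c_in)
    # Boolean mask with True marking padded positions; all False means no padding
    padding_mask = torch.zeros(batch_size, set_size, dtype=torch.bool)
    out = net(x, src_key_padding_mask=padding_mask)
    print(out.shape)  # expected: torch.Size([2, 5, 8])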