optimization.py
import optax
from optax import GradientTransformation
from optax._src.base import NO_PARAMS_MSG
import jax
import chex
import jax.numpy as jnp
import functools
from optax._src import numerics, wrappers
from flax.core.frozen_dict import FrozenDict
from flax.training import train_state
from mreserve.checkpoint import f32_to_bf16, bf16_to_f32
from optax._src.factorized import _factored_dims
import numpy as np
from typing import NamedTuple, Any


class ScaleByAdamState(NamedTuple):
    """State for the Adam algorithm."""
    count: chex.Array
    mu: optax.Updates
    nu: optax.Updates


def _bias_correction(moment, decay, count):
    """Perform bias correction. This becomes a no-op as count goes to infinity."""
    bias_correction = 1 - decay ** count
    return jax.tree_map(lambda t: t / bias_correction.astype(t.dtype), moment)
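
# Quick worked example of the correction: with decay b1 = 0.9, after one step the
# raw moment only holds 0.1 * grad, and dividing by 1 - 0.9 ** 1 = 0.1 restores the
# gradient's magnitude; as count grows the divisor approaches 1 and nothing changes.
#   grad = jnp.array([2.0])
#   mu = (1 - 0.9) * grad                      # -> [0.2]
#   _bias_correction(mu, 0.9, jnp.asarray(1))  # -> [2.0]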


#########################
# Bfloat16 encoding for the (non-negative) second-moment statistics.
# The value is cubed before the bfloat16 cast, trading unused exponent range for
# mantissa precision, and the otherwise-unused sign bit records whether the decoder
# should multiply back by `missing_precision` to land closer to the original value.
# (Author's note: this encoding probably isn't as good as I would want were I to do
# this again; the issue is that for really small numbers we almost never choose the
# negative option.)
missing_precision = 1 + (1 / 2 ** 9)


def _unsigned_bfloat16_decode(v):
    # Recover the float32 value: undo the optional precision factor, then the cube.
    v_abs = jnp.abs(v).astype(jnp.float32)
    v_abs = jax.lax.select(v >= 0, v_abs, v_abs * missing_precision)
    return jnp.cbrt(v_abs)


def _unsigned_bfloat16_encode(v):
    # Cube, cast to bfloat16, and use the sign bit to pick whichever of the two
    # decodings (with or without `missing_precision`) is closer to the true value.
    v_pow = jnp.power(v, 3)
    v_bf = v_pow.astype(jnp.bfloat16)
    v_bf32 = v_bf.astype(jnp.float32)
    err0 = jnp.abs(v_bf32 - v_pow)
    err1 = jnp.abs(v_bf32 * missing_precision - v_pow)
    return jax.lax.select(err0 < err1, v_bf, -v_bf)
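

# Illustrative round-trip check (a hypothetical helper, not used elsewhere): encode a
# few non-negative values the way `nu` is stored and confirm the decoded float32
# result stays within roughly bfloat16-level relative error.
def _bfloat16_roundtrip_relative_error(values=(1e-8, 1e-4, 0.5, 3.0)):
    v = jnp.asarray(values, dtype=jnp.float32)
    decoded = _unsigned_bfloat16_decode(_unsigned_bfloat16_encode(v))
    return jnp.abs(decoded - v) / v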


def scale_by_bfloat16_adam(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    use_bfloat16=True,
    do_bias_correction=True,
) -> GradientTransformation:
    """
    Adam-style rescaling that keeps the optimizer state in bfloat16.

    :param b1: Exponential decay rate for the first moment (mu).
    :param b2: Exponential decay rate for the second moment (nu).
    :param eps: Added to the denominator outside the square root.
    :param eps_root: Added to the denominator inside the square root.
    :param use_bfloat16: If False, fall back to the standard float32 optax.scale_by_adam.
    :param do_bias_correction: Whether to apply the usual Adam bias correction.
    :return: An optax GradientTransformation.
    """
    if not use_bfloat16:
        assert do_bias_correction
        return optax.scale_by_adam(b1, b2, eps, eps_root)

    _init = functools.partial(jnp.zeros_like, dtype=jnp.bfloat16)

    def init_fn(params):
        running_m = jax.tree_map(_init, params)
        running_v = jax.tree_map(_init, params)
        return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=running_m, nu=running_v)

    def _momentum_update(grad, current_m):
        # Cast up here
        current_m = current_m.astype(jnp.float32)
        next_m = (1 - b1) * grad + b1 * current_m
        return next_m

    def _secondorder_update(grad, current_v):
        # The second moment is stored with the cube / sign-bit encoding above.
        current_v_dec = _unsigned_bfloat16_decode(current_v)
        next_v = (1 - b2) * jnp.square(grad) + b2 * current_v_dec
        return next_v

    def update_fn(updates, state, params=None):
        del params
        next_m = jax.tree_multimap(_momentum_update, updates, state.mu)
        next_m_enc = jax.tree_map(lambda x: x.astype(jnp.bfloat16), next_m)
        next_v = jax.tree_multimap(_secondorder_update, updates, state.nu)
        next_v_enc = jax.tree_map(_unsigned_bfloat16_encode, next_v)
        count_inc = numerics.safe_int32_increment(state.count)
        if do_bias_correction:
            next_m = _bias_correction(next_m, b1, count_inc)
            next_v = _bias_correction(next_v, b2, count_inc)

        # Apply updates
        updates = jax.tree_multimap(
            lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), next_m, next_v)
        return updates, ScaleByAdamState(count=count_inc, mu=next_m_enc, nu=next_v_enc)

    return GradientTransformation(init_fn, update_fn)
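

# Minimal usage sketch (hypothetical helper and toy values): chain the transform with a
# learning-rate scale the way optax expects, then run a single update on a tiny
# parameter tree.
def _example_bfloat16_adam_step():
    params = {'w': jnp.ones([3, 2]), 'b': jnp.zeros([2])}
    grads = jax.tree_map(jnp.ones_like, params)
    tx = optax.chain(scale_by_bfloat16_adam(b1=0.9, b2=0.98), optax.scale(-1e-3))
    opt_state = tx.init(params)
    updates, opt_state = tx.update(grads, opt_state, params)
    return optax.apply_updates(params, updates)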


def lr_scale_linearwarmup_cosinedecay(num_warmup_steps, num_train_steps, final_lr_scale=0.1):
    """
    Learning-rate scale with linear warmup followed by cosine decay.

    :param num_warmup_steps: Linear warmup for this many steps
    :param num_train_steps: Cosine decay for num_train_steps - num_warmup_steps
    :param final_lr_scale: We will end at this * learning_rate
    :return: A schedule function mapping step -> scale in [0, 1].
    """
    assert num_warmup_steps <= num_train_steps

    def schedule(step):
        warmup_scale = step / num_warmup_steps
        post_warmup_scale = (step - num_warmup_steps) / (num_train_steps - num_warmup_steps + 1.0)
        post_warmup_scale = jnp.minimum(post_warmup_scale, 1.0)
        # linear -> cosine
        post_warmup_scale = (1.0 - (1.0 - jnp.cos(jnp.pi * post_warmup_scale)) / 2.0)
        post_warmup_scale = final_lr_scale + (1.0 - final_lr_scale) * post_warmup_scale
        return jax.lax.select(step < num_warmup_steps, warmup_scale, post_warmup_scale)

    return schedule


def lr_scale_linearwarmup_lineardecay(num_warmup_steps, num_train_steps):
    """
    Learning-rate scale with linear warmup followed by linear decay toward 0.

    :param num_warmup_steps: Linear warmup for this many steps
    :param num_train_steps: Linear decay for num_train_steps - num_warmup_steps
    :return: A schedule function mapping step -> scale in [0, 1].
    """
    assert num_warmup_steps <= num_train_steps

    def schedule(step):
        warmup_scale = step / num_warmup_steps
        post_warmup_scale = (step - num_warmup_steps) / (num_train_steps - num_warmup_steps + 1.0)
        post_warmup_scale = 1.0 - jnp.minimum(post_warmup_scale, 1.0)
        return jax.lax.select(step < num_warmup_steps, warmup_scale, post_warmup_scale)

    return schedule
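

# Quick illustration of the two schedules (hypothetical step counts): evaluate the
# returned functions at a few steps to see the linear warmup ramp, the peak of 1.0
# at the end of warmup, and the decay toward final_lr_scale (cosine) or zero (linear).
def _example_lr_schedules():
    cosine = lr_scale_linearwarmup_cosinedecay(num_warmup_steps=100, num_train_steps=1000,
                                               final_lr_scale=0.02)
    linear = lr_scale_linearwarmup_lineardecay(num_warmup_steps=100, num_train_steps=1000)
    steps = [0, 50, 100, 500, 1000]
    return {step: (cosine(jnp.asarray(step)), linear(jnp.asarray(step))) for step in steps}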


def construct_train_state(opt_config, model, params, return_chainables=False):
    """
    Builds the optax optimizer chain and wraps it in a flax TrainState.

    :param opt_config: Dict like
        {
            learning_rate: 0.0001
            num_train_steps: 60000  # 5 epochs
            num_warmup_steps: 10000
            weight_decay_rate: 0.1
            beta_2: 0.98
            clip_norm: 0.0
            adafactor: False
            use_bfloat16_adam: True
        }
    :param model: A flax module; its apply fn is stored on the TrainState.
    :param params: The model parameters.
    :param return_chainables: If True, return the list of optax transformations instead.
    :return: A flax train_state.TrainState (or the list of chainables).
    """
    opt = scale_by_bfloat16_adam(b1=opt_config.get('beta_1', 0.9),
                                 b2=opt_config.get('beta_2', 0.98),
                                 eps=opt_config.get('eps', 1e-8),
                                 use_bfloat16=opt_config.get('use_bfloat16_adam', True),
                                 do_bias_correction=opt_config.get('do_bias_correction', False),
                                 )
    chainables = [
        opt,
        # Decay only matrices (ndim > 1); biases and layernorm scales are excluded.
        optax.add_decayed_weights(weight_decay=opt_config['weight_decay_rate'],
                                  mask=lambda p: jax.tree_map(lambda x: x.ndim > 1, p),
                                  ),
        optax.scale_by_schedule(lr_scale_linearwarmup_cosinedecay(num_warmup_steps=opt_config['num_warmup_steps'],
                                                                  num_train_steps=opt_config['num_train_steps'],
                                                                  final_lr_scale=opt_config.get('final_lr_scale', 0.02),
                                                                  )),
        optax.scale(-opt_config['learning_rate']),
    ]
    if return_chainables:
        return chainables
    tx = optax.chain(*chainables)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
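

# End-to-end usage sketch; `_ToyModel` and the config values below are hypothetical
# stand-ins for the real model and training config handled elsewhere in the codebase.
def _example_construct_train_state():
    import flax.linen as nn

    class _ToyModel(nn.Module):
        @nn.compact
        def __call__(self, x):
            return nn.Dense(4)(x)

    model = _ToyModel()
    params = model.init(jax.random.PRNGKey(0), jnp.ones([1, 8]))['params']
    opt_config = {
        'learning_rate': 0.0001,
        'num_train_steps': 60000,
        'num_warmup_steps': 10000,
        'weight_decay_rate': 0.1,
        'beta_2': 0.98,
        'use_bfloat16_adam': True,
    }
    return construct_train_state(opt_config, model, params)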