-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathramps.py
29 lines (24 loc) · 876 Bytes
/
ramps.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import numpy as np
def sigmoid_rampup(current, rampup_length):
    """Return the sigmoid-shaped ramp-up weight ``exp(-5 * (1 - t) ** 2)``.

    Schedule from https://arxiv.org/abs/1610.02242. Here ``t`` is
    ``current / rampup_length`` after ``current`` has been clipped into
    ``[0, rampup_length]``, so the weight rises from ``exp(-5)`` at the start
    to exactly 1.0 once ``current`` reaches ``rampup_length``.

    Args:
        current: Progress counter; values outside ``[0, rampup_length]`` are
            clipped.
        rampup_length: Length of the ramp. A length of 0 is treated as a ramp
            that has already finished.

    Returns:
        The weight as a Python float.
    """
    # Zero-length ramp: nothing to ramp, and it would otherwise divide by zero.
    if rampup_length == 0:
        return 1.0
    clipped = np.clip(current, 0.0, rampup_length)
    remaining = 1.0 - clipped / rampup_length  # 1 at the start, 0 at the end
    return float(np.exp(-5.0 * remaining * remaining))
def linear_rampup(current, rampup_length):
    """Return a weight that grows linearly from 0 to 1 over ``rampup_length``.

    The weight is ``current / rampup_length`` until ``current`` reaches
    ``rampup_length``. From then on it is 1.0, which also covers a zero-length
    ramp without dividing by zero.

    Args:
        current: Non-negative progress counter (checked with ``assert``).
        rampup_length: Non-negative ramp length (checked with ``assert``).

    Returns:
        The ramp weight.
    """
    assert current >= 0 and rampup_length >= 0
    return current / rampup_length if current < rampup_length else 1.0
def cosine_rampdown(current, rampdown_length):
    """Return a cosine-annealed weight that decays from 1.0 to 0.0.

    Schedule from https://arxiv.org/abs/1608.03983. While
    ``current < rampdown_length`` the weight is
    ``0.5 * (cos(pi * current / rampdown_length) + 1)``. From
    ``rampdown_length`` onwards it is 0.0.

    Args:
        current: Non-negative progress counter (checked with ``assert``).
        rampdown_length: Length of the decay. A length of 0 means the
            ramp-down has already finished, so 0.0 is returned.

    Returns:
        The weight as a Python float in ``[0.0, 1.0]``.
    """
    assert 0 <= current
    # ``>=`` (not ``>``): at the endpoint the cosine term is already 0.0, and
    # short-circuiting here avoids a ZeroDivisionError for rampdown_length == 0.
    if current >= rampdown_length:
        return 0.0
    return float(0.5 * (np.cos(np.pi * current / rampdown_length) + 1))