docs: stepwise lr/noam example
JackTemaki authored and albertz committed Jan 6, 2021
1 parent 4a1f172 commit face0c3
Showing 1 changed file with 49 additions and 0 deletions: docs/configuration_reference/optimizer_settings.rst
@@ -24,6 +24,55 @@ Optimizer Settings
accum_grad_multiple_step
An integer specifying the number of consecutive steps over which the gradients are accumulated before an update is applied, known as "gradient accumulation".
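
A minimal sketch of how this might look in a config; the value ``4`` is only an illustrative choice, not taken from the original docs:

.. code-block:: python

    # accumulate gradients over 4 consecutive steps and apply the update only
    # every 4th step, roughly emulating a 4x larger effective batch size
    accum_grad_multiple_step = 4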

dynamic_learning_rate
This can be set as either a dictionary or a function.

When providing a dictionary, a cyclic learning rate can be implemented by setting the parameters
``interval`` and ``decay``.
The global learning rate is then multiplied by ``decay ** (global_train_step % interval)``.
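
For illustration, a possible dictionary setting could look like the following sketch (the concrete numbers are arbitrary example values):

.. code-block:: python

    # the global learning rate gets multiplied by 0.99 ** (global_train_step % 1000),
    # i.e. it decays within each interval of 1000 steps and resets at the interval boundary
    dynamic_learning_rate = {"interval": 1000, "decay": 0.99}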

When using a custom function, the parameters passed to it are ``network``, ``global_train_step`` and ``learning_rate``.
Do not forget to declare them as keyword-only arguments and to add ``**kwargs`` to keep the config
compatible with future changes.

An example of Noam-style learning rate scheduling would be:

.. code-block:: python

    learning_rate = 1  # can be higher, reasonable values may be up to 10 or even more
    learning_rate_control = "constant"

    def noam(n, warmup_n, model_d):
        """
        Noam style learning rate scheduling
        (k is identical to the global learning rate)

        :param int|float|tf.Tensor n:
        :param int|float|tf.Tensor warmup_n:
        :param int|float|tf.Tensor model_d:
        :return:
        """
        from returnn.tf.compat import v1 as tf
        model_d = tf.cast(model_d, tf.float32)
        n = tf.cast(n, tf.float32)
        warmup_n = tf.cast(warmup_n, tf.float32)
        return tf.pow(model_d, -0.5) * tf.minimum(tf.pow(n, -0.5), n * tf.pow(warmup_n, -1.5))

    def dynamic_learning_rate(*, network, global_train_step, learning_rate, **kwargs):
        """
        :param TFNetwork network:
        :param tf.Tensor global_train_step:
        :param tf.Tensor learning_rate: current global learning rate
        :param kwargs:
        :return:
        """
        WARMUP_N = 25000
        MODEL_D = 512
        return learning_rate * noam(n=global_train_step, warmup_n=WARMUP_N, model_d=MODEL_D)

gradient_clip
Specify a gradient clipping threshold.
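
As a minimal sketch, this option is set directly in the config; the threshold value here is only an example:

.. code-block:: python

    # clip gradients at a threshold of 5.0 before the parameter update
    gradient_clip = 5.0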
