ann.graph
This package contains an implementation of ANNs using a graph description. The graph allows the declaration of delayed connections, which can be used to develop recurrent neural networks such as LSTM or Elman networks.
Graph ANNs are declared as instances of the class ann.graph. The ANN graph is a model where nodes are ANN components (any object of ann.components, plus the objects defined in ann.graph and ann.graph.blocks). Nodes have input and output tokens, in the same way as ANN components receive a token and produce another token as output. A node can receive any number of input connections, and its output can be connected to multiple other nodes. When multiple input connections are received, they are put together into a tokens.vector.bunch instance. The graph properly implements the propagation of gradients between nodes, using the forward, backprop and compute_gradients methods of the component in every node.
The graph object is itself an ANN component, allowing the declaration of graphs whose nodes are other graphs. Every graph has two special nodes, 'input' and 'output', which are used to connect the visible parts of the ANN.
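As an illustration, the following minimal sketch declares a simple feed-forward graph; the name 'mlp' and the layer sizes are arbitrary choices for the example.
> g = ann.graph('mlp')
> g:connect('input',
            ann.components.hyperplane{ input=2, output=2 },
            ann.components.actf.logistic(),
            ann.components.hyperplane{ input=2, output=1 },
            ann.components.actf.logistic(),
            'output')
> g:build()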
Loops are allowed in ANN graphs, but they cannot be made of normal connections, so the concept of delayed connections is introduced (indeed, normal connections are those with delay=0). An ANN graph with delayed connections is equivalent to the concept of a Recurrent Neural Network (RNN in the following). The training of RNNs follows the Back-Propagation Through Time (BPTT) algorithm. RNNs have a special behavior in the forward method: the graph takes note of the state (input, output, gradient deltas, ...) of every node, allowing the activation at any past instant to be taken as input. The backprop method returns a tokens.null instance, that is, its output is none; this method just annotates the given input error deltas for future use. Calling the method compute_gradients propagates the error deltas given at backprop through all the space and time, and computes the weight gradients.
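Putting these methods together, a per-sequence training step could look like the following sketch. The inputs and deltas arrays are assumptions of the example (tokens prepared by the caller, one per time instant), and the weight update with an optimizer is omitted.
g:reset() -- reinitialize the BPTT tables before a new sequence
for t = 1, #inputs do
  local output = g:forward(inputs[t], true) -- during_training=true
  g:backprop(deltas[t]) -- returns tokens.null, just annotates the deltas
end
-- propagate the annotated deltas through space and time
local grads = g:compute_gradients()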
g = ann.graph( [ name ] )
The constructor receives an optional name argument.
g:connect( source, dest1, ..., [ delay=0 ] )
This method connects a path of nodes in the graph. It receives as arguments:
- source: The first node in the path. It can be an ANN component or the string 'input'.
- dest1: The second node in the path. It can be an ANN component or the string 'output'.
- ...: A variadic list of zero or more nodes which form the path. Every node in this list can be an ANN component or the string 'output'.
- delay=0: The last argument is optional and defaults to zero. It is needed to declare delayed connections: all the connections in the path will be declared with the delay given in this argument. Note that normally only the connection between two nodes needs to be delayed, and for this purpose the method g:delayed(source, destination) has been declared.
For correctness, a graph is valid only if the input node is a source, the output node is a sink, and every node is reachable from the input.
The following example shows how to declare a Jordan network.
> g = ann.graph()
> bind = ann.graph.bind()
> out_actf = ann.components.logistic()
> g:connect('input', bind,
ann.components.hyperplane{ input=2, output=4 },
ann.components.actf.logistic(),
ann.components.hyperplane{ input=4, output=1 },
out_actf, 'output')
> g:delayed(out_actf, bind) -- recurrent connection
> g:build()
g:delayed( source, destination )
This is equivalent to g:connect(source, destination, 1)
.
g:show_nodes()
This method is used for debugging purposes, and shows all the nodes in the given graph. If the graph contains other graphs as nodes, their nodes will be shown recursively. The output indicates the level of recursion, the name of the component in the corresponding node, and the type of the object, as in the following example (extracted from an LSTM test):
# (level) name type
(0) a2 ann.components.actf.log_logistic
(0) parity::output string
(0) LSTM ann.graph +
(1) LSTM::f::layer ann.components.hyperplane
(1) LSTM::input string
(1) LSTM::o::actf ann.components.actf.logistic
(1) LSTM::i::peephole ann.graph.bind
(1) LSTM::o::gate ann.graph.cmul
(1) LSTM::o::layer ann.components.hyperplane
(1) LSTM::f::gate ann.graph.cmul
(1) LSTM::actf ann.components.actf.softsign
(1) LSTM::memory ann.graph.add
(1) LSTM::o::peephole ann.graph.bind
(1) LSTM::i::actf ann.components.actf.logistic
(1) LSTM::cell_input ann.components.hyperplane
(1) LSTM::output string
(1) LSTM::i::gate ann.graph.cmul
(1) LSTM::f::actf ann.components.actf.logistic
(1) LSTM::i::layer ann.components.hyperplane
(1) LSTM::f::peephole ann.graph.bind
(0) l2 ann.components.hyperplane
(0) parity::input string
g:dot_graph(filename)
This method is for debugging purposes. It writes to the given filename a DOT graph which can be converted into PDF using Graphviz.
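For instance, the Jordan network declared above could be rendered as follows (the filename is just an example):
> g:dot_graph('jordan.dot')
and then, from the shell, dot -Tpdf jordan.dot -o jordan.pdf produces the PDF.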
g,weights,components = g:build( [ table ] )
See ANN package doc.
output = g:forward( input, [ during_training=false ] )
See ANN package doc.
output = g:backprop( input )
See ANN package doc.
Note that this method changes its default behavior when the graph is an RNN. In this case, the output of this method is a tokens.null instance, so it can be ignored.
table = g:compute_gradients( [table] )
See ANN package doc.
table = g:bptt_backprop()
Forces the BPTT algorithm execution, and returns a table (Lua array) with the delta gradients at the ANN input for every time instant.
table = g:get_bptt_state()
Returns a table with the state of the whole ANN for every time instant.
table = g:get_bptt_state(time)
Returns a table with the state of the whole ANN for the given time instant.
g:reset( [ n ] )
See ANN package doc.
In addition to the standard behavior, this method reinitializes the BPTT tables, and must be called before starting a new sequence.
boolean = g:get_is_recurrent()
Indicates if the caller graph is recurrent or not.
g:set_bptt_truncation( backstep )
Changes the BPTT algorithm behavior, truncating the gradient computation every backstep number of iterations. This value can be math.huge to indicate an infinite limit. Besides this value, another usual one is 1, which allows the algorithm to be used as a kind of on-line learning algorithm.
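For example (the value 10 is an arbitrary choice):
> g:set_bptt_truncation(10)        -- truncate gradients every 10 steps
> g:set_bptt_truncation(math.huge) -- never truncate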
ann.graph.bind{ [ name=string ], [ input=number ], [ output=number ], [ size=number ] }
ann.graph.add{ [ name=string ], [ input=number ], [ output=number ] }
ann.graph.cmul{ [ name=string ], [ input=number ], [ output=number ] }
ann.graph.index(n, { [ name=string ], [ input=number ], [ output=number ] })
ann.graph.blocks.elman{ [ name=string ], [ input=number ],
                        [ output=number ], [ actf=string ] }
ann.graph.blocks.lstm{ [ name=string ], [ input=number ],
                       [ output=number ], [ actf=string ],
                       [ peepholes=true ], [ input_gate=true ],
                       [ forget_gate=true ], [ output_gate=true ] }
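These constructors build prewired graph blocks that can be used as nodes of another graph. As a minimal sketch (the sizes and the 'softsign' activation are arbitrary example values, mirroring the show_nodes listing above), an LSTM-based network could be declared as:
> g = ann.graph()
> lstm = ann.graph.blocks.lstm{ name='LSTM', input=2, output=4,
                                actf='softsign' }
> g:connect('input', lstm,
            ann.components.hyperplane{ input=4, output=1 },
            ann.components.actf.log_logistic(), 'output')
> g:build()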