From 054c969bdcb0345834ce04e7d171937784c2419c Mon Sep 17 00:00:00 2001
From: mratsim
Date: Sun, 29 Apr 2018 19:53:04 +0200
Subject: [PATCH] TrainableLayer is now a concept instead of inheritable
 object: https://github.com/nim-lang/Nim/issues/7713

---
 src/nn_dsl/dsl_types.nim | 15 ++++++++++-----
 src/nn_dsl/nn_dsl.nim    | 15 ++++++++++++++-
 2 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/src/nn_dsl/dsl_types.nim b/src/nn_dsl/dsl_types.nim
index 0f67bbfca..2ed34095f 100644
--- a/src/nn_dsl/dsl_types.nim
+++ b/src/nn_dsl/dsl_types.nim
@@ -55,13 +55,18 @@ type
   #####################################################
   # Todo: move that in NN part
 
-  TrainableLayer*[TT] = object of RootObj
+  # We have to use concepts here: non-ref object inheritance
+  # doesn't work properly: https://github.com/nim-lang/Nim/issues/7713
+  TrainableLayer*[TT] = concept layer
+    layer.weight is Variable[TT]
+    layer.bias is Variable[TT]
+
+  Conv2DLayer*[TT] = object
+    weight*: Variable[TT]
+    bias*: Variable[TT]
+  LinearLayer*[TT] = object
     weight*: Variable[TT]
     bias*: Variable[TT]
-
-  Conv2DLayer*[T] = object of TrainableLayer[T]
-  LinearLayer*[T] = object of TrainableLayer[T]
-
 
 proc hash*(x: NimNode): Hash =
   assert x.kind == nnkIdent
diff --git a/src/nn_dsl/nn_dsl.nim b/src/nn_dsl/nn_dsl.nim
index f0276213a..6f9624d14 100644
--- a/src/nn_dsl/nn_dsl.nim
+++ b/src/nn_dsl/nn_dsl.nim
@@ -2,7 +2,10 @@
 # Distributed under the Apache v2 License (license terms are at http://www.apache.org/licenses/LICENSE-2.0).
 # This file may not be copied, modified, or distributed except according to those terms.
 
-import dsl_core, dsl_types
+import
+  macros,
+  dsl_core, dsl_types,
+  ../nn/nn
 
 export
   network,
@@ -20,3 +23,13 @@ proc flatten*(s: openarray[int]): int {.inline.}=
   result = 1
   for val in s:
     result *= val
+
+func optimizer*[M; U: SomeReal](model: M, optimizer: typedesc[Optimizer[U]], learning_rate: U): Optimizer[U] =
+  result.params = @[]
+  result.lr = learning_rate
+
+  for layer in fields(model):
+    when layer is TrainableLayer:
+      result.params.add layer.weight
+      result.params.add layer.bias
+
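
Note: below is a minimal, self-contained sketch of the structural matching
this patch relies on. DemoNet, Flatten and the simplified Variable/Optimizer
definitions are illustrative stand-ins for this note only, not Arraymancer's
actual types.

  type
    Variable[U] = object
      val: U

    # Structural interface: any object exposing matching `weight` and `bias`
    # fields satisfies the concept. No inheritance is involved, which
    # sidesteps https://github.com/nim-lang/Nim/issues/7713.
    TrainableLayer[U] = concept layer
      layer.weight is Variable[U]
      layer.bias is Variable[U]

    Conv2DLayer[U] = object
      weight, bias: Variable[U]

    LinearLayer[U] = object
      weight, bias: Variable[U]

    Flatten = object            # parameter-free layer, must be skipped

    DemoNet = object            # hypothetical model, for illustration only
      cv1: Conv2DLayer[float32]
      fl: Flatten
      fc1: LinearLayer[float32]

    Optimizer[U] = object
      params: seq[Variable[U]]
      lr: U

  func optimizer[M; U](model: M, learning_rate: U): Optimizer[U] =
    result.params = @[]
    result.lr = learning_rate
    # fields() unrolls at compile time, so `when` filters per field:
    # only fields that structurally satisfy TrainableLayer contribute
    # their parameters to the optimizer.
    for layer in fields(model):
      when layer is TrainableLayer:
        result.params.add layer.weight
        result.params.add layer.bias

  var net: DemoNet
  let opt = optimizer(net, 0.01'f32)
  echo opt.params.len   # 4: cv1.weight, cv1.bias, fc1.weight, fc1.bias

The compile-time `when` is what makes this work: since fields() is unrolled
at compile time, the `add` lines are only instantiated for fields that match
the concept. A plain `if` would still type-check the branch body for Flatten,
where `layer.weight` does not exist, and fail to compile.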