TrainableLayer is now a concept instead of inheritable object: nim-la…
mratsim committed Apr 29, 2018
1 parent bc83b5f commit 054c969
Showing 2 changed files with 24 additions and 6 deletions.
15 changes: 10 additions & 5 deletions src/nn_dsl/dsl_types.nim
@@ -55,13 +55,18 @@ type
 
   #####################################################
   # Todo: move that in NN part
-  TrainableLayer*[TT] = object of RootObj
+  # We have to use concepts here: non-ref object inheritance
+  # doesn't work properly: https://github.com/nim-lang/Nim/issues/7713
+  TrainableLayer*[TT] = concept layer
+    layer.weight is Variable[TT]
+    layer.bias is Variable[TT]
 
+  Conv2DLayer*[TT] = object
     weight*: Variable[TT]
     bias*: Variable[TT]
+  LinearLayer*[TT] = object
+    weight*: Variable[TT]
+    bias*: Variable[TT]
+
-  Conv2DLayer*[T] = object of TrainableLayer[T]
-  LinearLayer*[T] = object of TrainableLayer[T]
-
-
 proc hash*(x: NimNode): Hash =
   assert x.kind == nnkIdent
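The switch matters because `TrainableLayer` is now matched structurally: any object exposing `weight` and `bias` fields of type `Variable[TT]` satisfies the concept, with no base class required. That is why `Conv2DLayer` and `LinearLayer` drop the `object of TrainableLayer[T]` base and simply carry the two fields themselves. A minimal, self-contained sketch of how that compile-time matching behaves (the `Variable` stub and the layer types below are illustrative stand-ins, not the real definitions from this repository):

# Sketch only: `Variable` is a stub standing in for the autograd type.
type
  Variable[TT] = ref object
    value: TT

  TrainableLayer[TT] = concept layer
    layer.weight is Variable[TT]
    layer.bias is Variable[TT]

  LinearLayer[TT] = object
    weight, bias: Variable[TT]

  MaxPool = object          # no weight/bias, so it will not match
    kernel: int

static:
  # Structural checks, resolved entirely at compile time:
  doAssert LinearLayer[seq[float32]] is TrainableLayer[seq[float32]]
  doAssert MaxPool isnot TrainableLayer[seq[float32]]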
15 changes: 14 additions & 1 deletion src/nn_dsl/nn_dsl.nim
@@ -2,7 +2,10 @@
 # Distributed under the Apache v2 License (license terms are at http://www.apache.org/licenses/LICENSE-2.0).
 # This file may not be copied, modified, or distributed except according to those terms.
 
-import dsl_core, dsl_types
+import
+  macros,
+  dsl_core, dsl_types,
+  ../nn/nn
 
 export
   network,
@@ -20,3 +23,13 @@ proc flatten*(s: openarray[int]): int {.inline.}=
   result = 1
   for val in s:
     result *= val
+
+func optimizer*[M; U: SomeReal](model: M, optimizer: typedesc[Optimizer[U]], learning_rate: U): Optimizer[U] =
+  result.params = @[]
+  result.lr = learning_rate
+
+  for layer in fields(model):
+    when layer is TrainableLayer:
+      result.params.add layer.weight
+      result.params.add layer.bias
+
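The new `optimizer` builder relies on the same concept: `fields(model)` unrolls over the model's fields at compile time, and the `when layer is TrainableLayer` branch is only instantiated for fields that satisfy the concept, so parameter-free layers are skipped statically. A hypothetical end-to-end sketch of that pattern (every type name and literal below is invented for illustration; only the `fields`/`when` idiom mirrors this commit):

# Sketch of the compile-time field walk; not this repository's real API.
type
  Variable[TT] = ref object
    value: TT

  TrainableLayer[TT] = concept layer
    layer.weight is Variable[TT]
    layer.bias is Variable[TT]

  Dense = object
    weight, bias: Variable[seq[float32]]

  Flatten = object                      # parameter-free, skipped below

  Model = object
    fc1: Dense
    flat: Flatten
    fc2: Dense

  Optimizer[U] = object
    params: seq[Variable[seq[U]]]       # assumed parameter storage
    lr: U                               # learning rate

func optimizer[M; U](model: M, optimizerType: typedesc[Optimizer[U]], learning_rate: U): Optimizer[U] =
  result.params = @[]
  result.lr = learning_rate
  for layer in fields(model):           # unrolled at compile time
    when layer is TrainableLayer:       # concept check, per field
      result.params.add layer.weight
      result.params.add layer.bias

let model = Model(
  fc1: Dense(weight: Variable[seq[float32]](), bias: Variable[seq[float32]]()),
  flat: Flatten(),
  fc2: Dense(weight: Variable[seq[float32]](), bias: Variable[seq[float32]]()))

echo model.optimizer(Optimizer[float32], 0.01'f32).params.len  # 4: weight and bias from fc1 and fc2

Because the loop is unrolled and the `when` is resolved per field, a `Flatten` field costs nothing at runtime: the non-matching branch is simply never compiled in for it.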