forked from FluxML/model-zoo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: model.jl
74 lines (60 loc) · 2.26 KB
/
model.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
using Flux
using Flux: onehot, onehotbatch, logitcrossentropy, reset!, throttle
using Statistics: mean
using Random
using Unicode
using Parameters: @with_kw
"""
Hyperparameters for the character-level language-detection model.

`langs_len` and `alphabet_len` default to 0 and are filled in by
`get_processed_data` once the corpus has been scanned.
"""
@with_kw mutable struct Args
lr::Float64 = 1e-3 # learning rate
N::Int = 15 # Number of perceptrons in hidden layer
test_len::Int = 100 # length of test data
langs_len::Int = 0 # Number of different languages in Corpora
alphabet_len::Int = 0 # Total number of characters possible, in corpora
throttle::Int = 10 # throttle timeout
end
"""
    get_processed_data(args)

Read every `*.txt` file under `corpus/`, normalize the text, and build a
shuffled dataset of `(one-hot sentence, one-hot language)` pairs.

Mutates `args`: sets `args.langs_len` and `args.alphabet_len`.
Returns `(train, test)` where the last `args.test_len` samples form the
test split.
"""
function get_processed_data(args)
    corpora = Dict{Symbol, Vector}()
    for file in readdir("corpus")
        # Skip anything that is not a .txt file (e.g. stray README, .DS_Store);
        # the original code dereferenced `match(...)` unconditionally and would
        # crash on such entries.
        m = match(r"(.*)\.txt$", file)
        isnothing(m) && continue
        lang = Symbol(m.captures[1])
        # Split into sentences, casefold + strip diacritics, drop empties.
        corpus = split(String(read("corpus/$file")), ".")
        corpus = strip.(Unicode.normalize.(corpus, casefold=true, stripmark=true))
        corpus = filter(!isempty, corpus)
        corpora[lang] = corpus
    end
    langs = collect(keys(corpora))
    args.langs_len = length(langs)
    # Characters outside this alphabet are one-hot encoded as '_' (unknown).
    alphabet = ['a':'z'; '0':'9'; ' '; '\n'; '_']
    args.alphabet_len = length(alphabet)
    # NOTE(review): the original computed the set of out-of-alphabet characters
    # here and discarded the result (dead code, presumably a leftover REPL
    # inspection) — removed.
    dataset = [(onehotbatch(s, alphabet, '_'), onehot(l, langs)) for l in langs for s in corpora[l]] |> shuffle
    train, test = dataset[1:end-args.test_len], dataset[end-args.test_len+1:end]
    return train, test
end
"""
    build_model(args)

Construct the two-part model: a character `scanner` (dense embedding of
one-hot characters followed by an LSTM) and an `encoder` mapping the final
hidden state to one score per language.
"""
function build_model(args)
    char_scanner = Chain(
        Dense(args.alphabet_len, args.N, σ),  # embed one-hot characters
        LSTM(args.N, args.N),                 # accumulate sequence context
    )
    lang_encoder = Dense(args.N, args.langs_len)  # hidden state → language scores
    return char_scanner, lang_encoder
end
"""
    model(x, scanner, encoder)

Run `scanner` over every one-hot column of `x`, keep only the final hidden
state, reset the recurrent state, and return the encoder's language scores
(logits).
"""
function model(x, scanner, encoder)
    final_state = last(scanner.(x.data))
    reset!(scanner)  # clear LSTM state so the next sample starts fresh
    return encoder(final_state)
end
"""
    train(; kws...)

Build the dataset and model, then train with ADAM, reporting the mean
test-set loss via a throttled callback. Keyword arguments override the
defaults in `Args` (e.g. `lr`, `N`, `test_len`, `throttle`).
"""
function train(; kws...)
# Initialize Hyperparameters
args = Args(; kws...)
# Load Data (also fills in args.langs_len / args.alphabet_len)
train_data, test_data = get_processed_data(args)
@info("Constructing Model...")
scanner, encoder = build_model(args)
# Per-sample loss on the model's logits.
loss(x, y) = logitcrossentropy(model(x, scanner, encoder), y)
# Mean loss over the held-out split, used only for progress reporting.
testloss() = mean(loss(t...) for t in test_data)
opt = ADAM(args.lr)
ps = params(scanner, encoder)
evalcb = () -> @show testloss()
@info("Training...")
# Callback is throttled so evaluation runs at most once per args.throttle seconds.
Flux.train!(loss, ps, train_data, opt, cb = throttle(evalcb, args.throttle))
end
# Run relative to this file so the "corpus" directory resolves correctly.
cd(@__DIR__)
train()