Skip to content

Commit

Permalink
Merge pull request #41 from CortexFoundation/wlt
Browse files Browse the repository at this point in the history
update train mnist
  • Loading branch information
wlt-cortex authored Jun 17, 2020
2 parents bfb239d + 08dde7a commit 3470799
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
1 change: 0 additions & 1 deletion .github/workflows/ccpp.yml

This file was deleted.

23 changes: 19 additions & 4 deletions tests/mrt/train_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,22 @@
from mrt import conf, utils

import numpy as np
import argparse

parser = argparse.ArgumentParser(description='Mnist Traning')
parser.add_argument('--cpu', default=False, action='store_true',
help='whether enable cpu (default use gpu)')
parser.add_argument('--gpu-id', type=int, default=0,
help='gpu device id')
parser.add_argument('--net', type=str, default='',
help='choose available networks, optional: lenet, mlp')

args = parser.parse_args()

def load_fname(version, suffix=None, with_ext=False):
suffix = "."+suffix if suffix is not None else ""
prefix = "{}/mnist_{}{}".format(conf.MRT_MODEL_ROOT, version, suffix)
version = "_"+version if version is not None else ""
prefix = "{}/mnist{}{}".format(conf.MRT_MODEL_ROOT, version, suffix)
return utils.extend_fname(prefix, with_ext)

def data_xform(data):
Expand All @@ -25,10 +37,12 @@ def data_xform(data):
train_loader = mx.gluon.data.DataLoader(train_data, shuffle=True, batch_size=batch_size)
val_loader = mx.gluon.data.DataLoader(val_data, shuffle=False, batch_size=batch_size)

version = ''
version = args.net
print ("Training {} Mnist".format(version))

# Set the gpu device id
ctx = mx.gpu(0)
ctx = mx.cpu() if args.cpu else mx.gpu(args.gpu_id)
print ("Using device: {}".format(ctx))

def train_mnist():
# Select a fixed random seed for reproducibility
Expand Down Expand Up @@ -70,6 +84,8 @@ def train_mnist():
nn.Dense(64, activation='relu'),
nn.Dense(10, activation=None) # loss function includes softmax already, see below
)
else:
assert False

# Random initialize all the mnist model parameters
net.initialize(mx.init.Xavier(), ctx=ctx)
Expand Down Expand Up @@ -118,5 +134,4 @@ def train_mnist():
fout.write(sym.tojson())
net.collect_params().save(param_file)

print ("Test mnist", version)
train_mnist()

0 comments on commit 3470799

Please sign in to comment.