MXNet Gluon Study Notes

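TensorFlow is still widely regarded as the most mainstream framework, and its stability and performance are well proven across the board, while newer frameworks still have room to improve in stability. For research work, though, it is worth experimenting with the newer frameworks.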

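Gluon was released after PyTorch, and a framework that arrives later can adopt the strengths of earlier frameworks while avoiding their known weaknesses.
Below is the Gluon version of linear regression, adapted from https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter02_supervised-learning/linear-regression-gluon.ipynb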
#!/usr/bin/env python

import mxnet as mx
from mxnet import nd, autograd, gluon

# Both the data and the model live on the CPU; mx.gpu() would select a GPU instead
data_ctx = mx.cpu()
model_ctx = mx.cpu()

num_inputs = 2
num_outputs = 1
num_examples = 10000

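# Ground truth: y = 2*x0 - 3.4*x1 + 4.2
def real_fn(X):
    return 2 * X[:, 0] - 3.4 * X[:, 1] + 4.2

# Synthetic inputs drawn from N(0, 1); labels get a small Gaussian noise term
X = nd.random_normal(shape=(num_examples, num_inputs))
noise = 0.01 * nd.random_normal(shape=(num_examples,))
y = real_fn(X) + noise

# Shuffled mini-batches of 4 examples each
batch_size = 4
train_data = gluon.data.DataLoader(gluon.data.ArrayDataset(X, y),
                                   batch_size=batch_size, shuffle=True)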

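# Dense(1, in_units=2): one output unit, input dimension declared up front
net = gluon.nn.Dense(1, in_units=2)

print(net.weight)
print(net.bias)

# collect_params() returns a ParameterDict holding every parameter of the block
print(net.collect_params())
print(type(net.collect_params()))

# Initialize the parameters on model_ctx, using Normal(sigma=1.) as the initializer
net.collect_params().initialize(mx.init.Normal(sigma=1.), ctx=model_ctx)

# Forward pass on a single example
example_data = nd.array([[4, 7]])
print(net(example_data))

print(net.weight.data())
print(net.bias.data())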

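# Re-create the layer without in_units: Gluon infers the input shape
# on the first forward pass (deferred initialization)
net = gluon.nn.Dense(1)
net.collect_params().initialize(mx.init.Normal(sigma=1.), ctx=model_ctx)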

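# L2Loss computes 0.5 * (output - label)**2 for each example
square_loss = gluon.loss.L2Loss()

# Plain SGD; trainer.step(batch_size) normalizes the gradients by 1/batch_size
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.0001})

epochs = 10
loss_sequence = []
num_batches = num_examples // batch_size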

for e in range(epochs):
    cumulative_loss = 0
    # inner loop over mini-batches
    for data, label in train_data:
        data = data.as_in_context(model_ctx)
        label = label.as_in_context(model_ctx)
        # Record the forward pass so autograd can compute gradients
        with autograd.record():
            output = net(data)
            loss = square_loss(output, label)
        loss.backward()
        trainer.step(batch_size)
        cumulative_loss += nd.mean(loss).asscalar()
    # cumulative_loss adds up one mean loss per batch, so divide by the
    # number of batches (not examples) to get the average loss per example
    epoch_loss = cumulative_loss / num_batches
    print("Epoch %s, loss: %s" % (e, epoch_loss))
    loss_sequence.append(epoch_loss)
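The training loop above hides the math inside `L2Loss`, `autograd` and `trainer.step(batch_size)`. As a rough illustration only, here is a minimal plain-Python sketch of the same mini-batch SGD linear regression, assuming the standard forms of these pieces: the per-example loss is 0.5 * (prediction - label)**2, the gradients are worked out by hand instead of by autograd, and the summed batch gradient is divided by `batch_size` before the update. The helpers `make_data`, `batches` and `train` are hypothetical names made up for this sketch, and starting the bias at zero is a choice of the sketch, not something taken from the original code.

```python
import random

# Hypothetical helpers (make_data, batches, train) for illustration only;
# they are not part of MXNet or Gluon.


def real_fn(x0, x1):
    # Same ground truth as real_fn in the Gluon script: y = 2*x0 - 3.4*x1 + 4.2
    return 2 * x0 - 3.4 * x1 + 4.2


def make_data(num_examples, seed=0):
    # Inputs drawn from N(0, 1); labels get 0.01 * N(0, 1) noise added
    rng = random.Random(seed)
    data = []
    for _ in range(num_examples):
        x0, x1 = rng.gauss(0, 1), rng.gauss(0, 1)
        data.append(((x0, x1), real_fn(x0, x1) + 0.01 * rng.gauss(0, 1)))
    return data


def batches(data, batch_size, rng):
    # Rough stand-in for DataLoader(..., shuffle=True): shuffle indices, then slice
    idx = list(range(len(data)))
    rng.shuffle(idx)
    for start in range(0, len(idx), batch_size):
        yield [data[i] for i in idx[start:start + batch_size]]


def train(data, epochs=10, batch_size=4, lr=0.0001, seed=0):
    rng = random.Random(seed)
    w = [rng.gauss(0, 1), rng.gauss(0, 1)]  # weights drawn from N(0, 1)
    b = 0.0                                 # this sketch starts the bias at zero
    history = []
    for _ in range(epochs):
        total_loss = 0.0
        for batch in batches(data, batch_size, rng):
            gw0 = gw1 = gb = 0.0
            for (x0, x1), y in batch:
                err = w[0] * x0 + w[1] * x1 + b - y  # prediction minus label
                total_loss += 0.5 * err * err        # L2 loss for one example
                # Gradient of 0.5 * err**2 is err times d(prediction)/d(param)
                gw0 += err * x0
                gw1 += err * x1
                gb += err
            # Summed gradients scaled by 1/batch_size before the SGD update
            w[0] -= lr * gw0 / batch_size
            w[1] -= lr * gw1 / batch_size
            b -= lr * gb / batch_size
        history.append(total_loss / len(data))  # average loss per example
    return w, b, history


if __name__ == "__main__":
    # lr is larger than the script's 0.0001 so that this short run converges
    w, b, history = train(make_data(10000), epochs=10, batch_size=4, lr=0.01)
    print("w =", w, "b =", b)
    print("per-epoch average loss:", history)
```

With 10000 examples and a batch size of 4, every batch is full (2500 batches per epoch), so dividing the summed gradient by `batch_size` is the same as averaging it. Gluon's `trainer.step(batch_size)` likewise scales by the number passed in, not by the actual length of the batch.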


Posted: 2017-11-17 15:37:56

Original link (please keep when reposting): http://www.multisilicon.com/blog/a25324161.html
