Apache MXNet as a backend for Keras 2
In a previous post, we saw how Apache MXNet could be used as a backend for Keras 1.2, and how fast it was.
Lo and behold, you can now use it with Keras 2 (with some restrictions). This is great news for both the Keras and MXNet communities, so how about a collective thumbs up to @sandeep-krishnamurthy, @jiajiechen, @karan6181, @roywei and @kalyc for their contributions.
This post will show you how to install this version of Keras, tell you what is and isn’t supported at the moment and, of course, show you how to run your training jobs faster by using multiple GPUs.

Installation
Let’s fire up an EC2 instance running the Deep Learning AMI. I’ll use the Conda version, which lets us easily manage separate Python environments… and avoid making a huge mess of everything ;)
$ conda create -n mxnet-keras python=3.6
$ source activate mxnet-keras
$ pip3 install mxnet-cu90 --upgrade
$ pip3 install keras-mxnet --upgrade
$ python3
>>> import mxnet, keras
>>> print("%s %s" % (mxnet.__version__, keras.__version__))
1.2.0 2.1.6
Good. Now let’s make sure that MXNet is actually set as the backend for Keras.
$ cat ~/.keras/keras.json
{
    "backend": "mxnet",
    "image_data_format": "channels_first"
}
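If the file doesn’t exist yet, or still points at another backend, you can edit it by hand. Here is a minimal sketch of doing the same thing from Python, assuming the standard ~/.keras/keras.json location. The helper name set_keras_backend is hypothetical; it only overwrites the two keys shown above and leaves any other settings (such as floatx or epsilon) alone.

```python
import json
import os


def set_keras_backend(path=os.path.expanduser("~/.keras/keras.json"),
                      backend="mxnet",
                      data_format="channels_first"):
    """Hypothetical helper: point Keras at a backend by editing keras.json.

    Existing keys (e.g. "floatx", "epsilon") are preserved; only "backend"
    and "image_data_format" are overwritten.
    """
    config = {}
    if os.path.exists(path):
        with open(path) as f:
            config = json.load(f)
    config["backend"] = backend
    config["image_data_format"] = data_format
    # Create the parent directory (e.g. ~/.keras) if it doesn't exist yet.
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, "w") as f:
        json.dump(config, f, indent=4)
    return config
```

Run it before the first import keras: Keras reads this file once, when its backend module is imported.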
Setting ‘image_data_format’ to ‘channels_first’ makes MXNet training faster, since it matches MXNet’s native NCHW layout for convolutions. When working with image data, the input shape can be either ‘channels_first’, i.e. (number of channels, height, width), or ‘channels_last’, i.e. (height, width, number of channels).
For MNIST, this would be either (1, 28, 28) or (28, 28, 1): one channel (grayscale pictures), 28 pixels by 28 pixels. For ImageNet, it would be (3, 224, 224) or (224, 224, 3): three channels (red, green and blue), 224 pixels by 224 pixels.
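In a training script, the usual pattern is to check which layout is configured and build the input shape to match. Here is a minimal sketch: image_input_shape is a hypothetical helper, and in a real script you would pass it the value returned by keras.backend.image_data_format() rather than a hard-coded string.

```python
def image_input_shape(channels, height, width, data_format):
    """Return the per-sample input shape for the given image_data_format."""
    if data_format == "channels_first":
        # (number of channels, height, width): the layout MXNet prefers
        return (channels, height, width)
    if data_format == "channels_last":
        # (height, width, number of channels)
        return (height, width, channels)
    raise ValueError("unknown image_data_format: %r" % data_format)


print(image_input_shape(1, 28, 28, "channels_first"))   # MNIST: (1, 28, 28)
print(image_input_shape(3, 224, 224, "channels_last"))  # ImageNet: (224, 224, 3)
```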
We’re ready to play!
Supported features
This is still a work in progress. You can check the super detailed release notes (great work, team) to see what is and isn’t supported.
In a nutshell (copying from the release notes):
- Supports Convolutional Neural Network (CNN) and experimental Recurrent Neural Network (RNN) training and inference.
- Supports high performance, distributed Multi-GPU training of CNN and RNN networks.
- Supports exporting a native MXNet model from a Keras-MXNet trained model, enabling faster research with the Keras interface and high-performance, large-scale inference in production with the native MXNet engine. You can use all of MXNet’s language bindings (Scala/Python/Julia/R/Perl) for inference on the exported model.
- Adds a Keras benchmarking utility for running CNN and RNN benchmarks with standard networks and datasets. Supports benchmarking on CPU, one GPU and multi-GPU distributed training.
A few comments:
- Of course, multi-GPU training is extremely important. We’ll see in a second how easy it is to add it to any existing Keras script.
- According to the team’s benchmarks, using Keras with MXNet brings a major performance boost over other backends. The detailed benchmarks are available, and they’re most definitely worth a read ;)
- The third point is extremely important IMHO. You can use Keras for fast experimentation, possibly reusing or tweaking models from the vast collection available out there. Then, you can export the trained model as a native MXNet model and use it in production with MXNet only, which speeds things up further: inference then runs directly on MXNet’s C++ engine, without going through Keras. Very useful feature. You’ll find a complete example here.
Multi-GPU training
In the Keras 1.2 version, this was done by passing an MXNet-style context (a list of GPU names) to compile().
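from keras.optimizers import SGD

# Build an MXNet-style context: one entry per GPU ('gpu(0)', 'gpu(1)', ...)
NUM_GPU = 4
gpu_list = []
for i in range(NUM_GPU):
    gpu_list.append('gpu(%d)' % i)
# The 'context' argument was specific to the Keras 1.2 MXNet backend
model.compile(
    loss='categorical_crossentropy',
    optimizer=SGD(),
    metrics=['accuracy'],
    context=gpu_list)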
Now all we have to do is:
from keras.optimizers import SGD
from keras.utils import multi_gpu_model
...
# Wrap the existing model so that training is spread over 4 GPUs
model = multi_gpu_model(model, gpus=4)
model.compile(loss='categorical_crossentropy',
              optimizer=SGD(),
              metrics=['accuracy'])
This is definitely more convenient: just wrap the model with multi_gpu_model() before compiling it, and voilà!
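Under the hood, this is data parallelism: each batch is split into slices, each GPU runs the forward and backward pass on its own slice, and the gradients are aggregated before the weights are updated. The sketch below only illustrates the splitting arithmetic. split_batch is a hypothetical helper, not keras-mxnet’s actual code, and it assumes samples are divided as evenly as possible, with any remainder going to the first GPUs.

```python
def split_batch(batch_size, num_gpus):
    """Return the number of samples each GPU processes for one batch.

    Hypothetical illustration of data-parallel batch splitting: samples are
    divided as evenly as possible, and the first `batch_size % num_gpus`
    GPUs each take one extra sample.
    """
    if num_gpus < 1:
        raise ValueError("need at least one GPU")
    base, remainder = divmod(batch_size, num_gpus)
    return [base + 1 if i < remainder else base for i in range(num_gpus)]


print(split_batch(32, 4))   # [8, 8, 8, 8]
print(split_batch(10, 4))   # [3, 3, 2, 2]
```

One practical consequence: if you keep the same batch size when adding GPUs, each GPU gets less work per step, which is why it is common to scale the batch size with the number of GPUs.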
Example
Using a p3.16xlarge instance (8 NVIDIA V100 GPUs), let’s clone the keras-mxnet repo.
$ git clone https://github.com/awslabs/keras-apache-mxnet.git
First, we’re going to run the cifar10_resnet.py script, which trains ResNet20v1 on the CIFAR-10 dataset. Unsurprisingly, it only uses a single GPU ;)

One epoch takes 1309 seconds (almost 22 minutes).
1563/1563 [==============================] - 1309s 838ms/step - loss: 1.6875 - acc: 0.4412 - val_loss: 1.7708 - val_acc: 0.4323
Now, let’s run the cifar10_resnet_multi_gpu.py script on 8 GPUs :) All it takes is adding the couple of lines shown above.

One epoch now takes 190 seconds.
1563/1563 [==============================] - 190s 122ms/step - loss: 1.2006 - acc: 0.6315 - val_loss: 1.2958 - val_acc: 0.61
1309/190 ≈ 6.89. Not quite an 8x speedup, but pretty close!
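To put a number on “pretty close”: the speedup is the ratio of the two epoch times, and dividing it by the number of GPUs gives the scaling efficiency. A quick check using the epoch times from the two training logs above:

```python
single_gpu_epoch_s = 1309  # cifar10_resnet.py, 1 GPU
multi_gpu_epoch_s = 190    # cifar10_resnet_multi_gpu.py, 8 GPUs
num_gpus = 8

speedup = single_gpu_epoch_s / multi_gpu_epoch_s
efficiency = speedup / num_gpus

print("speedup: %.2fx" % speedup)                         # speedup: 6.89x
print("scaling efficiency: %.0f%%" % (100 * efficiency))  # scaling efficiency: 86%
```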
That’s it for today. I’m really excited about the combination of Keras and MXNet. Flexibility and speed! Please give it a try.
Happy to answer questions here or on Twitter. For more content, please feel free to check out my YouTube channel.