Parallel Training

Please view the README to learn about installing, setting up dependencies, and importing notebooks in Zeppelin.


Training neural network models can be computationally expensive. To speed up training, you can train your models in parallel on multiple GPUs, if your machine has them installed. With deeplearning4j (DL4J), this is not difficult to do. In this tutorial we use the MNIST dataset (a dataset of handwritten digit images) to train a convolutional neural network in parallel across multiple GPUs.

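First, if your pom.xml file is configured to use the CPU backend by default, you must switch it to the GPU backend. In the DL4J examples parent pom, the file containing the lines

<name>DeepLearning4j Examples Parent</name>
<description>Examples of training different data sets</description>

the ND4J backend should be changed from the CPU artifact, nd4j-native-platform, to the CUDA artifact that matches your installed CUDA version, nd4j-cuda-X.X-platform (where X.X is that version).

Next, import the classes used in this tutorial:
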
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.deeplearning4j.parallelism.ParallelWrapper;

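To obtain the data, we use DL4J's built-in MnistDataSetIterator with a random seed of 12345, creating one iterator for the training set (second argument true) and one for the test set (false). These DataSetIterators feed the data directly into a neural network, one minibatch of batchSize examples at a time.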

val batchSize = 128
val mnistTrain = new MnistDataSetIterator(batchSize,true,12345)
val mnistTest = new MnistDataSetIterator(batchSize,false,12345)
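
As a quick sanity check on what batchSize = 128 means, the plain-Scala sketch below (no DL4J calls) counts the minibatches in one pass over each split, using MNIST's standard 60,000 training and 10,000 test images. It assumes the iterator emits a final partial batch when the split size is not a multiple of the batch size; the helper names are hypothetical.

```scala
// Hypothetical helpers: minibatch arithmetic only, no DL4J calls.
// Number of minibatches needed to cover n examples (ceiling division).
def numBatches(n: Int, batchSize: Int): Int = (n + batchSize - 1) / batchSize

// Size of the final minibatch, assuming the iterator emits a partial last batch.
def lastBatchSize(n: Int, batchSize: Int): Int =
  if (n % batchSize == 0) batchSize else n % batchSize

val trainExamples = 60000 // standard MNIST training split
val testExamples = 10000  // standard MNIST test split

println(s"train: ${numBatches(trainExamples, 128)} minibatches, last holds ${lastBatchSize(trainExamples, 128)}")
println(s"test: ${numBatches(testExamples, 128)} minibatches, last holds ${lastBatchSize(testExamples, 128)}")
```

With a batch size of 128 this prints 469 training minibatches (the last holding 96 images) and 79 test minibatches (the last holding 16).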

Next, we define a convolutional neural network configuration and initialize the model.

val nChannels = 1
val outputNum = 10
val seed = 123

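val conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .weightInit(WeightInit.XAVIER)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) //no updater set, so library defaults apply
            .list()
            //layer sizes (20 and 50 filters, 500 dense units) follow DL4J's LeNet MNIST example
            .layer(0, new ConvolutionLayer.Builder(5, 5)
                //nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied
                .nIn(nChannels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build())
            .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2).stride(2, 2).build())
            .layer(2, new ConvolutionLayer.Builder(5, 5)
                //Note that nIn need not be specified in later layers
                .stride(1, 1).nOut(50).activation(Activation.IDENTITY).build())
            .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2).stride(2, 2).build())
            .layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build())
            .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(outputNum).activation(Activation.SOFTMAX).build())
            .setInputType(InputType.convolutionalFlat(28,28,1)) //flattened 28x28 grayscale input; lets DL4J infer nIn for later layers
            .build()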

val model = new MultiLayerNetwork(conf)
model.init()
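
Because setInputType(InputType.convolutionalFlat(28,28,1)) tells DL4J the input shape, it can work out the input size of each later layer itself. The plain-Scala sketch below reproduces that bookkeeping by hand. It assumes zero padding with sizes rounded down, each spatial size becoming (in - kernel) / stride + 1, which is assumed here to match DL4J's default truncate convolution mode. It also assumes 20 and 50 filters and 2x2 max pooling with stride 2, as in DL4J's LeNet MNIST example. The helper and value names are hypothetical.

```scala
// Hypothetical helper: spatial output size of an unpadded convolution or pooling window,
// rounded down by integer division.
def outSize(in: Int, kernel: Int, stride: Int): Int = (in - kernel) / stride + 1

// Assumed LeNet-style filter counts (not taken from the tutorial text).
val conv1Filters = 20
val conv2Filters = 50

val afterConv1 = outSize(28, 5, 1)         // 5x5 convolution, stride 1
val afterPool1 = outSize(afterConv1, 2, 2) // 2x2 max pooling, stride 2
val afterConv2 = outSize(afterPool1, 5, 1)
val afterPool2 = outSize(afterConv2, 2, 2)

// The dense layer sees the flattened activations of the last pooling layer.
val denseNIn = afterPool2 * afterPool2 * conv2Filters

println(s"conv1: ${afterConv1}x${afterConv1}x$conv1Filters")
println(s"pool1: ${afterPool1}x${afterPool1}x$conv1Filters")
println(s"conv2: ${afterConv2}x${afterConv2}x$conv2Filters")
println(s"pool2: ${afterPool2}x${afterPool2}x$conv2Filters, dense nIn = $denseNIn")
```

For these assumed sizes, the second pooling layer outputs 4 x 4 x 50 activations, so the dense layer receives 800 inputs. Deriving such numbers automatically is why later layers can omit nIn.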

Next, we configure parallel training with the ParallelWrapper class, passing the MultiLayerNetwork to its builder. ParallelWrapper takes care of load balancing between the GPUs.

The idea is that the model is duplicated inside the ParallelWrapper. Each of the prespecified number of workers (here, 2) then trains its own copy of the model on its own share of the data. After a specified number of iterations (here, 3), the workers' models are averaged, and each worker receives a copy of the averaged model. Training continues in this way until it is complete.

val wrapper = new ParallelWrapper.Builder(model)
            .workers(2).averagingFrequency(3).build() //2 workers; average the models every 3 iterations
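
To make the averaging scheme described above concrete, here is a toy, single-threaded Scala sketch of synchronous parameter averaging. It is not ParallelWrapper's implementation and uses no DL4J classes, threads, or GPUs. It fits a one-parameter model y = w * x with plain SGD, gives each of 2 simulated workers its own shard of a made-up dataset, and averages the workers' weights after every 3 local steps. The data, learning rate, and all names are hypothetical.

```scala
// Toy illustration of synchronous parameter averaging (hypothetical data and names);
// this is not DL4J's ParallelWrapper implementation.
// Model: y = w * x, trained with SGD on squared error, one example per step.
val trueW = 3.0
val data: Vector[(Double, Double)] =
  Vector.tabulate(60) { i => val x = (i % 10 + 1) / 10.0; (x, trueW * x) }

val numWorkers = 2
val averagingFrequency = 3 // local SGD steps between averaging rounds
val learningRate = 0.5

// Split the data round-robin so each worker trains on its own shard.
val shards: Vector[Vector[(Double, Double)]] =
  Vector.tabulate(numWorkers) { k =>
    data.zipWithIndex.collect { case (example, i) if i % numWorkers == k => example }
  }

// One SGD step: the gradient of 0.5 * (w * x - y)^2 with respect to w is (w * x - y) * x.
def sgdStep(w: Double, example: (Double, Double)): Double = {
  val (x, y) = example
  w - learningRate * (w * x - y) * x
}

var globalW = 0.0
val rounds = shards.map(_.size).min / averagingFrequency
for (r <- 0 until rounds) {
  // Each worker starts the round from a copy of the current averaged weight
  // and takes averagingFrequency SGD steps on the next examples of its shard.
  val localWeights = shards.map { shard =>
    shard.slice(r * averagingFrequency, (r + 1) * averagingFrequency).foldLeft(globalW)(sgdStep)
  }
  // The round ends by averaging the workers' weights into the shared model.
  globalW = localWeights.sum / numWorkers
}

println(f"after $rounds rounds: w = $globalW%.4f (true value $trueW)")
```

Each local step shrinks the error w - 3 by a factor of (1 - learningRate * x * x), so after 10 rounds the averaged weight ends up close to 3. In a real multi-GPU run, averaging less often reduces synchronization overhead but lets the workers' copies drift further apart between rounds.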

To train the model, call the ParallelWrapper's fit method directly on the training DataSetIterator, as in wrapper.fit(mnistTrain). Because ParallelWrapper handles all the training details behind the scenes, parallelizing training with DL4J takes very little code.
