The goal of this project is to benchmark the performance of PyTorch and MXNet on a recurrent-network-based model using single- and multi-GPU configurations. Findings are reported in this [blog post](http://borealisai.com/2018/02/16/standardizing-a-machine-learning-framework-for-applied-research/).
## Hardware and Software Configuration
For easy reproducibility, results are reported on an EC2 instance using a community Deep Learning image. The benchmarks were also run on Borealis AI machines to see how performance varies across hardware.
| Hardware | Value |
|---|---|
| EC2 instance type | p3.2xlarge |
| GPU type | Tesla V100 |
Launch an instance of your choice using the Deep Learning AMI (ami-9ba7c4e1). The results below were reported using a p3.8xlarge instance, since we wanted to benchmark the Volta architecture with multiple GPUs.
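If you would rather script the launch than use the console, a boto3 call along these lines should work (a minimal sketch; the region, key pair name, and security group are placeholders, not values from this project):

```python
import boto3

ec2 = boto3.client("ec2", region_name="us-east-1")  # region is an assumption

response = ec2.run_instances(
    ImageId="ami-9ba7c4e1",            # Deep Learning AMI referenced above
    InstanceType="p3.8xlarge",         # multi-GPU Volta instance used here
    MinCount=1,
    MaxCount=1,
    KeyName="my-key-pair",             # placeholder: your EC2 key pair
    SecurityGroupIds=["sg-xxxxxxxx"],  # placeholder: a group allowing SSH
)
print(response["Instances"][0]["InstanceId"])
```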
```
ssh <ec2-instance-ip>

## PYTORCH ##
cd pytorch
# open run.sh and check parameters are what you need
source activate pytorch_p36
./run.sh example.pkl

## MXNET ##
cd mxnet
# open run.sh and check parameters are what you need
source activate mxnet_p36
./run.sh example.pkl
```
Set up a virtual environment and install the following libraries. Note that the commands below install the framework versions used in this benchmark.
```
pip install http://download.pytorch.org/whl/cu90/torch-0.3.0.post4-cp36-cp36m-linux_x86_64.whl
pip install mxnet-cu90
```
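After installing, a quick check like the following (not part of this repo) confirms that both frameworks import correctly and can see the GPUs:

```python
# Verify both frameworks load and detect the GPU(s).
import torch
import mxnet as mx

print("PyTorch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("GPUs visible to PyTorch:", torch.cuda.device_count())
print("MXNet:", mx.__version__)

# A tiny allocation on GPU 0 confirms the CUDA build of MXNet works.
x = mx.nd.ones((2, 2), ctx=mx.gpu(0))
print("MXNet GPU tensor:\n", x.asnumpy())
```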
Navigate to either the pytorch or mxnet directory and run as follows.
```
cd <framework-name>
./run.sh example.pkl
```
run.sh is a bash script that wraps the Python run script run.py. To modify the arguments passed, open run.sh and change the values accordingly. The full list of supported arguments can be viewed with:
```
python <framework>/run.py --help
```
## Model and hyperparameters
See <framework>/model.py for details on the model.
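As a rough illustration only (the real architecture is in model.py; the layer sizes and structure below are made up, not this project's), a recurrent model in PyTorch generally looks like this:

```python
import torch.nn as nn

class RecurrentModel(nn.Module):
    """Illustrative LSTM model; see <framework>/model.py for the real one."""
    def __init__(self, input_size=128, hidden_size=256, num_layers=2, num_classes=10):
        super(RecurrentModel, self).__init__()
        self.rnn = nn.LSTM(input_size, hidden_size,
                           num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # x has shape (batch, seq_len, input_size)
        output, _ = self.rnn(x)
        # Predict from the hidden state at the final time step.
        return self.fc(output[:, -1, :])
```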
Hyperparameters are as follows:
- Optimizer: Adam
- Learning rate: 0.001, with a factored schedule that reduces the learning rate by a fixed factor during training (see the sketch after this list)
- Weight initialization: Xavier normal with gain=2, using both N<sub>in</sub> and N<sub>out</sub>
- Batch size: 256, 512, and 1024 were tested to plot convergence rates
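To make these settings concrete, here is a sketch of how they might be wired up in PyTorch; StepLR stands in for the factored schedule, and its step_size and gamma values are placeholders since the exact decay parameters are not listed here:

```python
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

model = nn.LSTM(input_size=128, hidden_size=256, num_layers=2)  # stand-in model

# Xavier normal init computes its std from both fan-in (N_in) and fan-out
# (N_out); gain=2 as described above. On the torch 0.3 build pinned earlier,
# the call is `xavier_normal` (no trailing underscore).
for param in model.parameters():
    if param.dim() >= 2:  # skip 1-D bias vectors
        nn.init.xavier_normal_(param, gain=2)

optimizer = Adam(model.parameters(), lr=0.001)

# "Factored" schedule: multiply the learning rate by a fixed factor every
# `step_size` epochs. step_size and gamma here are placeholders, not the
# values used in the benchmark.
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)
```

Calling scheduler.step() once per epoch applies the decay.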
Please see the blog post above for the results.