Horovod

Horovod is a distributed deep learning training framework for TensorFlow, Keras, PyTorch, and Apache MXNet. The goal of Horovod is to make distributed deep learning fast and easy to use.

LF AI & Data

Horovod is hosted by the LF AI & Data Foundation (LF AI & Data). If you are a company that is deeply committed to using open source technologies in artificial intelligence, machine learning, and deep learning, and want to support the communities of open source projects in these domains, consider joining the LF AI & Data Foundation. For details about who's involved and how Horovod plays a role, read the Linux Foundation announcement.




The primary motivation for this project is to make it easy to take a single-GPU training script and successfully scale it to train across many GPUs in parallel. This has two aspects:

  1. How much modification does one have to make to a program to make it distributed, and how easy is it to run it?
  2. How much faster would it run in distributed mode?

Internally at Uber we found the MPI model to be much more straightforward and to require far fewer code changes than previous solutions such as Distributed TensorFlow with parameter servers. Once a training script has been written for scale with Horovod, it can run on a single GPU, multiple GPUs, or even multiple hosts without any further code changes. See the Usage section for more details.

In addition to being easy to use, Horovod is fast. Below is a chart representing the benchmark that was done on 128 servers with 4 Pascal GPUs each, connected by a RoCE-capable 25 Gbit/s network:

512-GPU Benchmark

Horovod achieves 90% scaling efficiency for both Inception V3 and ResNet-101, and 68% scaling efficiency for VGG-16. See Benchmarks to find out how to reproduce these numbers.

While installing MPI and NCCL may seem like an extra hassle, it only needs to be done once by the team dealing with infrastructure, while everyone else in the company who builds the models can enjoy the simplicity of training them at scale.

To install Horovod on Linux or macOS:

  1. Install CMake

  2. If you've installed TensorFlow from PyPI, make sure that g++-5 or above is installed. Starting with TensorFlow 2.10, a C++17-compliant compiler like g++8 or above will be required.

    If you've installed PyTorch from PyPI, make sure that g++-5 or above is installed.

    If you've installed either package from Conda, make sure that the gxx_linux-64 Conda package is installed.

  3. Install the horovod pip package.

    To run on CPUs:

    $ pip install horovod

    To run on GPUs with NCCL:

    $ HOROVOD_GPU_OPERATIONS=NCCL pip install horovod

For more details on installing Horovod with GPU support, read Horovod on GPU.

For the full list of Horovod installation options, read the Installation Guide.

If you want to use MPI, read Horovod with MPI.

If you want to use Conda, read Building a Conda environment with GPU support for Horovod.

If you want to use Docker, read Horovod in Docker.

To compile Horovod from source, follow the instructions in the Contributor Guide.

Horovod core principles are based on MPI concepts such as size, rank, local rank, allreduce, allgather, broadcast, and alltoall. See this page for more details.
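To make these terms concrete, here is a minimal pure-Python simulation of the concepts. It does not use Horovod or MPI: each list index stands in for one process, and the helper names (`ranks_layout`, `allreduce_sum`, and so on) are hypothetical and chosen only for illustration. It also assumes ranks are assigned host by host, which is a simplification.

```python
# Pure-Python simulation of MPI-style concepts; no Horovod or MPI required.
# Each list index plays the role of one process (its rank).

def ranks_layout(hosts):
    """Return (rank, local_rank, host) for every process.

    Assumption: ranks are handed out host by host, slot by slot.
    """
    layout = []
    rank = 0
    for host, slots in hosts.items():
        for local_rank in range(slots):
            layout.append((rank, local_rank, host))
            rank += 1
    return layout


def allreduce_sum(values):
    """Every rank contributes one value and receives the global sum."""
    total = sum(values)
    return [total for _ in values]


def allgather(values):
    """Every rank receives the list of values from all ranks."""
    return [list(values) for _ in values]


def broadcast(values, root):
    """Every rank receives the value held by the root rank."""
    return [values[root] for _ in values]


def alltoall(chunks):
    """chunks[i][j] is what rank i sends to rank j."""
    size = len(chunks)
    return [[chunks[i][j] for i in range(size)] for j in range(size)]


layout = ranks_layout({"server1": 2, "server2": 2})
size = len(layout)  # total number of processes
for rank, local_rank, host in layout:
    print("rank %d of %d is local rank %d on %s" % (rank, size, local_rank, host))
```

In this sketch, size is 4, and rank 2 is local rank 0 on the second host, which is why local rank (not rank) is the natural choice for picking a GPU on each server.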

See these pages for Horovod examples and best practices:

To use Horovod, make the following additions to your program:

  1. Run hvd.init() to initialize Horovod.

  2. Pin each GPU to a single process to avoid resource contention.

    With the typical setup of one GPU per process, set this to local rank. The first process on the server will be allocated the first GPU, the second process will be allocated the second GPU, and so forth.

  3. Scale the learning rate by the number of workers.

    Effective batch size in synchronous distributed training is scaled by the number of workers. An increase in learning rate compensates for the increased batch size.

  4. Wrap the optimizer in hvd.DistributedOptimizer.

    The distributed optimizer delegates gradient computation to the original optimizer, averages gradients using allreduce or allgather, and then applies those averaged gradients.

  5. Broadcast the initial variable states from rank 0 to all other processes.

    This is necessary to ensure consistent initialization of all workers when training is started with random weights or restored from a checkpoint.

  6. Modify your code to save checkpoints only on worker 0 to prevent other workers from corrupting them.
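The learning-rate scaling step is simple arithmetic, sketched below in plain Python. The function name `scaled_hyperparameters` and the batch size of 32 are assumptions made for illustration; in a real script the worker count would come from hvd.size().

```python
def scaled_hyperparameters(base_lr, per_worker_batch_size, num_workers):
    """Return (scaled learning rate, effective batch size).

    Each worker processes its own batch per step, so the effective batch
    size grows linearly with the number of workers. The learning rate is
    scaled by the same factor to compensate.
    """
    effective_batch_size = per_worker_batch_size * num_workers
    scaled_lr = base_lr * num_workers
    return scaled_lr, effective_batch_size


# Example: 16 workers, each with a (hypothetical) batch size of 32.
lr, batch = scaled_hyperparameters(base_lr=0.01,
                                   per_worker_batch_size=32,
                                   num_workers=16)
print("learning rate: %.2f, effective batch size: %d" % (lr, batch))
```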

Example using TensorFlow v1 (see the examples directory for full training examples):

import tensorflow as tf
import horovod.tensorflow as hvd


# Initialize Horovod
hvd.init()

# Pin GPU to be used to process local rank (one GPU per process)
config = tf.ConfigProto()
config.gpu_options.visible_device_list = str(hvd.local_rank())

# Build model...
loss = ...
opt = tf.train.AdagradOptimizer(0.01 * hvd.size())

# Add Horovod Distributed Optimizer
opt = hvd.DistributedOptimizer(opt)

# Add hook to broadcast variables from rank 0 to all other processes during
# initialization.
hooks = [hvd.BroadcastGlobalVariablesHook(0)]

# Make training operation
train_op = opt.minimize(loss)

# Save checkpoints only on worker 0 to prevent other workers from corrupting them.
checkpoint_dir = '/tmp/train_logs' if hvd.rank() == 0 else None

# The MonitoredTrainingSession takes care of session initialization,
# restoring from a checkpoint, saving to a checkpoint, and closing when done
# or an error occurs.
with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
                                       config=config,
                                       hooks=hooks) as mon_sess:
  while not mon_sess.should_stop():
    # Perform synchronous training.
    mon_sess.run(train_op)
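To see why averaging gradients across workers is a sensible thing for a distributed optimizer to do, the following plain-Python sketch simulates two workers on a toy one-parameter model. It is not Horovod code: `mse_gradient` and `allreduce_average` are hypothetical helpers, and the data is made up. Under the assumption that every worker holds an equally sized shard, the averaged per-worker gradient matches the gradient computed over the full batch.

```python
def mse_gradient(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def allreduce_average(values):
    """Simulated allreduce: every worker receives the mean of all values."""
    mean = sum(values) / len(values)
    return [mean for _ in values]


# Toy data for the model y = w * x, split evenly across two simulated workers.
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
w = 0.5
shards = [(xs[0:2], ys[0:2]), (xs[2:4], ys[2:4])]

# Each worker computes a gradient on its own shard only.
local_grads = [mse_gradient(w, sx, sy) for sx, sy in shards]

# After the simulated allreduce, every worker holds the same averaged gradient.
averaged = allreduce_average(local_grads)

# Reference: gradient over the whole batch on a single worker.
full_batch_grad = mse_gradient(w, xs, ys)

print("local gradients:", local_grads)
print("averaged gradient:", averaged[0])
print("full-batch gradient:", full_batch_grad)
```

Because every worker applies the same averaged gradient to identically initialized weights, the replicas stay in sync, which is also why the initial broadcast from rank 0 matters.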

The example commands below show how to run distributed training. See Run Horovod for more details, including RoCE/InfiniBand tweaks and tips for dealing with hangs.

  1. To run on a machine with 4 GPUs:

    $ horovodrun -np 4 -H localhost:4 python train.py

  2. To run on 4 machines with 4 GPUs each:

    $ horovodrun -np 16 -H server1:4,server2:4,server3:4,server4:4 python train.py

  3. To run using Open MPI without the horovodrun wrapper, see Running Horovod with Open MPI.

  4. To run in Docker, see Horovod in Docker.

  5. To run on Kubernetes, see Helm Chart, Kubeflow MPI Operator, FfDL, and Polyaxon.

  6. To run on Spark, see Horovod on Spark.

  7. To run on Ray, see Horovod on Ray.

  8. To run in Singularity, see Singularity.

  9. To run in an LSF HPC cluster (e.g. Summit), see LSF.

  10. To run on Hadoop Yarn, see TonY.
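The -np and -H arguments in the first two commands follow a regular pattern: -np is the total number of processes, and -H lists each host with its slot count. The sketch below assembles such a command line from a host-to-slots mapping. The helper `horovodrun_command` is hypothetical (it is not part of Horovod) and only builds the string; it does not launch anything.

```python
import shlex


def horovodrun_command(hosts, script):
    """Build a horovodrun command line from a {host: slots} mapping.

    -np is the sum of all slots; -H is a comma-separated host:slots list.
    """
    total_processes = sum(hosts.values())
    host_arg = ",".join("%s:%d" % (host, slots) for host, slots in hosts.items())
    parts = ["horovodrun", "-np", str(total_processes),
             "-H", host_arg, "python", script]
    # Quote each part so the result is safe to paste into a shell.
    return " ".join(shlex.quote(part) for part in parts)


single = horovodrun_command({"localhost": 4}, "train.py")
multi = horovodrun_command(
    {"server1": 4, "server2": 4, "server3": 4, "server4": 4}, "train.py")
print(single)
print(multi)
```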

Gloo is an open source collective communications library developed by Facebook.

Gloo comes included with Horovod, and allows users to run Horovod without requiring MPI to be installed.

For environments that support both MPI and Gloo, you can choose to use Gloo at runtime by passing the --gloo argument to horovodrun:

$ horovodrun --gloo -np 2 python train.py

Horovod supports mixing and matching Horovod collectives with other MPI libraries, such as mpi4py, provided that MPI was built with multi-threading support.

You can check for MPI multi-threading support by querying the hvd.mpi_threads_supported() function.

import horovod.tensorflow as hvd

# Initialize Horovod
hvd.init()

# Verify that MPI multi-threading is supported.
assert hvd.mpi_threads_supported()

from mpi4py import MPI
assert hvd.size() == MPI.COMM_WORLD.Get_size()

You can also initialize Horovod with an mpi4py sub-communicator, in which case each sub-communicator will run an independent Horovod training.

from mpi4py import MPI
import horovod.tensorflow as hvd

# Split COMM_WORLD into subcommunicators
subcomm = MPI.COMM_WORLD.Split(color=MPI.COMM_WORLD.rank % 2,
                               key=MPI.COMM_WORLD.rank)

# Initialize Horovod
hvd.init(comm=subcomm)

print('COMM_WORLD rank: %d, Horovod rank: %d' % (MPI.COMM_WORLD.rank, hvd.rank()))
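To show which ranks the split above would produce, here is a small pure-Python simulation for four processes. It needs neither MPI nor Horovod. The `split` helper is hypothetical and only mimics the grouping rule of a communicator split: processes with the same color land in the same sub-communicator and are ordered by key.

```python
def split(world_size, color_fn, key_fn):
    """Simulate a communicator split.

    Returns {world_rank: (color, sub_rank)}. Processes sharing a color form
    one sub-communicator; within it, ranks are ordered by key, with ties
    broken by the original world rank.
    """
    groups = {}
    for world_rank in range(world_size):
        groups.setdefault(color_fn(world_rank), []).append(world_rank)

    mapping = {}
    for color, members in groups.items():
        ordered = sorted(members, key=lambda r: (key_fn(r), r))
        for sub_rank, world_rank in enumerate(ordered):
            mapping[world_rank] = (color, sub_rank)
    return mapping


# Same color and key choice as the example: color = rank % 2, key = rank.
mapping = split(4, color_fn=lambda r: r % 2, key_fn=lambda r: r)
for world_rank in sorted(mapping):
    color, sub_rank = mapping[world_rank]
    print('COMM_WORLD rank: %d, Horovod rank: %d' % (world_rank, sub_rank))
```

With four processes, world ranks 0 and 2 form one group and 1 and 3 form the other, so each group would run its own independent two-process training.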

Learn how to optimize your model for inference and remove Horovod operations from the graph here.

One of the unique things about Horovod is its ability to interleave communication and computation coupled with the ability to batch small allreduce operations, which results in improved performance. We call this batching feature Tensor Fusion.

See here for full details and tweaking instructions.
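The idea of batching small allreduce operations can be sketched as greedy packing into a fixed-size buffer. The code below is a simplified illustration only, not Horovod's actual implementation: the `fuse` function, the tensor names, the byte sizes, and the 1024-byte threshold are all made up for the example.

```python
def fuse(ready_tensors, threshold_bytes):
    """Greedily group tensors, in ready order, into batches.

    ready_tensors is a list of (name, size_in_bytes). A batch is closed as
    soon as adding the next tensor would exceed threshold_bytes. A tensor
    larger than the threshold ends up in a batch of its own.
    """
    batches = []
    current = []
    current_bytes = 0
    for name, nbytes in ready_tensors:
        if current and current_bytes + nbytes > threshold_bytes:
            batches.append(current)
            current, current_bytes = [], 0
        current.append(name)
        current_bytes += nbytes
    if current:
        batches.append(current)
    return batches


# Hypothetical gradients that became ready in this order.
ready = [("a/grad", 400), ("b/grad", 16), ("c/grad", 500),
         ("d/grad", 700), ("e/grad", 8)]
batches = fuse(ready, threshold_bytes=1024)
print("%d tensors reduced in %d fused operations" % (len(ready), len(batches)))
for batch in batches:
    print(batch)
```

Here five small tensors are reduced in two operations instead of five, which is the kind of saving that makes the buffer size worth tuning.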

Horovod has the ability to record the timeline of its activity, called Horovod Timeline.

Horovod Timeline

Use Horovod Timeline to analyze Horovod performance. See here for full details and usage instructions.

Selecting the right values to efficiently make use of Tensor Fusion and other advanced Horovod features can involve a good amount of trial and error. We provide a system to automate this performance optimization process called autotuning, which you can enable with a single command line argument to horovodrun.

See here for full details and usage instructions.

Horovod allows you to concurrently run distinct collective operations in different groups of processes taking part in one distributed training. Set up hvd.process_set objects to make use of this capability.

See Process Sets for detailed instructions.
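As a rough picture of what running a collective inside a group of processes means, the following pure-Python sketch averages values separately within two groups of ranks. It does not use Horovod's process set API; `allreduce_average_in_sets` is a hypothetical helper, and the grouping into even and odd ranks is an arbitrary choice for the example.

```python
def allreduce_average_in_sets(values, process_sets):
    """Average values independently within each group of ranks.

    values[r] is the value held by rank r. process_sets is a list of rank
    lists. Ranks that belong to no set keep their original value.
    """
    result = list(values)
    for members in process_sets:
        mean = sum(values[r] for r in members) / len(members)
        for r in members:
            result[r] = mean
    return result


# Four simulated ranks, split into even ranks and odd ranks.
values = [1.0, 2.0, 3.0, 4.0]
reduced = allreduce_average_in_sets(values, [[0, 2], [1, 3]])
print(reduced)
```

Each group only sees its own members' values, so the two reductions are independent of each other.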

  1. Run distributed training in Microsoft Azure using Batch AI and Horovod.
  2. Distributed model training using Horovod.

Send us links to any user guides you want to publish on this site.

See Troubleshooting and submit a ticket if you can't find an answer.

Please cite Horovod in your publications if it helps your research:

@article{sergeev2018horovod,
Author = {Alexander Sergeev and Mike Del Balso},
Journal = {arXiv preprint arXiv:1802.05799},
Title = {Horovod: fast and easy distributed deep learning in {TensorFlow}},
Year = {2018}
}

1. Sergeev, A., Del Balso, M. (2017) Meet Horovod: Uber’s Open Source Distributed Deep Learning Framework for TensorFlow. Retrieved from https://eng.uber /horovod/

2. Sergeev, A. (2017) Horovod - Distributed TensorFlow Made Easy. Retrieved from https:// slideshare.net/AlexanderSergeev4/horovod-distributed-tensorflow-made-easy

3. Sergeev, A., Del Balso, M. (2018) Horovod: fast and easy distributed deep learning in TensorFlow. Retrieved from arXiv:1802.05799

The Horovod source code was based on the Baidu tensorflow-allreduce repository written by Andrew Gibiansky and Joel Hestness. Their original work is described in the article Bringing HPC Techniques to Deep Learning.