title Feature Overview
layout single
permalink /features/
toc true
toc_label Contents

Distributed Training with Mixed Precision

Mixed Precision Training

Enable 16-bit (FP16) training in the deepspeed_config JSON.

"fp16": {
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
}

Single-GPU, Multi-GPU, and Multi-Node Training

Easily switch between single-GPU, single-node multi-GPU, or multi-node multi-GPU execution by specifying resources with a hostfile.

deepspeed --hostfile=<hostfile> \
	<client_entry.py> <client args> \
	--deepspeed --deepspeed_config ds_config.json

The script <client_entry.py> will execute on the resources specified in <hostfile>.

Model Parallelism

Support for Custom Model Parallelism

DeepSpeed supports all forms of model parallelism, including tensor-slicing approaches such as Megatron-LM and pipelined parallelism approaches such as PipeDream or GPipe. It does so by requiring only that the model parallelism framework provide a model parallelism unit (mpu) that implements a few bookkeeping functions:

mpu.get_model_parallel_rank()
mpu.get_model_parallel_group()
mpu.get_model_parallel_world_size()

mpu.get_data_parallel_rank()
mpu.get_data_parallel_group()
mpu.get_data_parallel_world_size()
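
As an illustration of the bookkeeping involved, here is a minimal sketch of an object exposing the six methods above. The class name SimpleMPU, its constructor arguments, and the rank layout (consecutive global ranks form one model-parallel group) are assumptions made for this example. Groups are returned as plain tuples of global ranks; a real mpu would return torch.distributed process groups instead, so this stub is for reasoning about the layout only.

```python
class SimpleMPU:
    """Hypothetical bookkeeping object with the six mpu methods.

    Assumption: consecutive global ranks form one model-parallel group.
    Groups are tuples of global ranks, not real process groups.
    """

    def __init__(self, global_rank, world_size, model_parallel_size):
        if world_size % model_parallel_size != 0:
            raise ValueError("world_size must be divisible by model_parallel_size")
        self.global_rank = global_rank
        self.world_size = world_size
        self.mp_size = model_parallel_size

    def get_model_parallel_rank(self):
        # Position of this rank inside its model-parallel group.
        return self.global_rank % self.mp_size

    def get_model_parallel_group(self):
        start = (self.global_rank // self.mp_size) * self.mp_size
        return tuple(range(start, start + self.mp_size))

    def get_model_parallel_world_size(self):
        return self.mp_size

    def get_data_parallel_rank(self):
        # Index of the model replica this rank belongs to.
        return self.global_rank // self.mp_size

    def get_data_parallel_group(self):
        # All ranks holding the same model shard, one per replica.
        first = self.get_model_parallel_rank()
        return tuple(range(first, self.world_size, self.mp_size))

    def get_data_parallel_world_size(self):
        return self.world_size // self.mp_size


# 8 GPUs with 2-way model parallelism, viewed from global rank 5.
mpu = SimpleMPU(global_rank=5, world_size=8, model_parallel_size=2)
```

With this layout, rank 5 sits in model-parallel group (4, 5) and shares its model shard with ranks (1, 3, 5, 7).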

Integration with Megatron-LM

DeepSpeed is fully compatible with Megatron. Please see the Megatron-LM tutorial for details.

Memory and Bandwidth Optimizations

The Zero Redundancy Optimizer (ZeRO)

ZeRO is at the heart of DeepSpeed and enables large model training at a scale that is simply not possible with model parallelism alone. When enabled, ZeRO allows training models with over 6 billion parameters without any model parallelism, and up to 100 billion parameter models with model parallelism on current generation hardware.

For more details, see the ZeRO paper and the GPT tutorial on integration with DeepSpeed. Additional tutorials, including a BERT tutorial, are coming soon.

Constant Buffer Optimization (CBO)

CBO enables high network and memory throughput while restricting memory usage to a constant size. For memory- and network-bound operations such as normalization or allreduce collectives, performance depends on the size of the operand. Simply fusing all operands into a single large operand can deliver great throughput, but at the expense of unnecessary memory overhead. CBO in DeepSpeed instead fuses smaller operands into a buffer of approximately a pre-defined size, large enough to achieve great performance without the unnecessary memory overhead.
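
The fusing idea can be sketched as a simple bucketing pass. This is an illustrative sketch only, not DeepSpeed's implementation: the function name bucket_operands, the buffer_size parameter, and the rule of closing a bucket as soon as it reaches the target are assumptions chosen for clarity.

```python
def bucket_operands(sizes, buffer_size):
    """Group operand sizes into buckets of roughly `buffer_size` elements.

    A bucket is closed as soon as its total reaches `buffer_size`, so it
    exceeds the target by less than the size of its last operand.
    """
    buckets, current, current_total = [], [], 0
    for size in sizes:
        current.append(size)
        current_total += size
        if current_total >= buffer_size:
            buckets.append(current)
            current, current_total = [], 0
    if current:
        buckets.append(current)  # flush the partially filled last bucket
    return buckets


# Seven operands (element counts) fused into buffers of about 1000 elements.
sizes = [300, 200, 600, 50, 50, 900, 100]
buckets = bucket_operands(sizes, buffer_size=1000)
```

Each bucket would then be processed as one fused operand, so the working buffer stays near the target size however many operands the model has.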

Smart Gradient Accumulation

Gradient accumulation allows running larger batch sizes with limited memory by breaking an effective batch into several sequential micro-batches and averaging the parameter gradients across these micro-batches. Furthermore, instead of averaging the gradients of each micro-batch across all GPUs, DeepSpeed averages the gradients locally during each step of the sequence and performs a single allreduce at the end of the sequence to produce the averaged gradients for the effective batch across all GPUs. This strategy significantly reduces communication compared with averaging globally for each micro-batch, especially when the number of micro-batches per effective batch is large.
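
The following toy sketch compares the two strategies on scalar gradients. It simulates GPUs with plain Python lists, and allreduce_mean is a hypothetical stand-in for a real allreduce collective; the point is only that both strategies produce the same average while the local strategy issues one collective instead of one per micro-batch.

```python
def allreduce_mean(per_gpu_values, counter):
    """Stand-in for an allreduce: average one value per GPU and count the call."""
    counter["allreduce_calls"] += 1
    return sum(per_gpu_values) / len(per_gpu_values)


def global_every_micro_batch(grads, counter):
    # grads[g][m] is the gradient of micro-batch m on GPU g.
    num_micro = len(grads[0])
    reduced = [
        allreduce_mean([gpu[m] for gpu in grads], counter)
        for m in range(num_micro)
    ]
    return sum(reduced) / num_micro


def local_then_single_allreduce(grads, counter):
    # Average locally over micro-batches, then communicate once.
    local_means = [sum(gpu) / len(gpu) for gpu in grads]
    return allreduce_mean(local_means, counter)


grads = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]  # 2 GPUs, 4 micro-batches
naive_counter = {"allreduce_calls": 0}
smart_counter = {"allreduce_calls": 0}
naive = global_every_micro_batch(grads, naive_counter)
smart = local_then_single_allreduce(grads, smart_counter)
```

Here the naive variant issues four collectives and the local variant issues one, and both return the same averaged gradient.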

Training Features

Simplified training API

The DeepSpeed core API consists of just a handful of methods:

  • initialization: initialize
  • training: backward and step
  • argument parsing: add_config_arguments
  • checkpointing: load_checkpoint and save_checkpoint

DeepSpeed supports all the features described in this document through these APIs, along with a deepspeed_config JSON file for enabling and disabling the features. Please see the core API doc for more details.
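
A training loop built on these methods might look like the sketch below. It is a hedged outline, not a reference implementation: the train function, the assumption that the model's forward pass returns the loss, and the assumption that each batch is a single tensor are all choices made for this example. It is meant to be started through the deepspeed launcher with --deepspeed and --deepspeed_config on the command line.

```python
import argparse


def train(model, dataset, num_epochs=1):
    """Sketch of a DeepSpeed training loop (hypothetical helper)."""
    import deepspeed  # imported lazily; requires DeepSpeed to be installed

    parser = argparse.ArgumentParser()
    # Some launcher versions pass the local rank as a command-line argument.
    parser.add_argument("--local_rank", type=int, default=-1)
    # Adds --deepspeed and --deepspeed_config to the parser.
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()

    # initialize wraps the model and builds the optimizer and data loader
    # according to the deepspeed_config JSON file.
    model_engine, optimizer, train_loader, _ = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=model.parameters(),
        training_data=dataset,
    )

    for _ in range(num_epochs):
        for batch in train_loader:
            batch = batch.to(model_engine.local_rank)  # assumption: batch is one tensor
            loss = model_engine(batch)   # assumption: forward returns the loss
            model_engine.backward(loss)  # backward pass through the engine
            model_engine.step()          # optimizer update through the engine
    return model_engine
```

Features such as FP16, gradient clipping, and ZeRO are then switched on or off in the JSON file without changing this loop.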

Gradient Clipping

DeepSpeed handles gradient clipping under the hood based on the max gradient norm specified by the user. Please see the core API doc for more details.
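
For readers unfamiliar with norm-based clipping, the sketch below shows the usual technique on a flat list of gradient values. It assumes the max gradient norm refers to the global L2 norm over all gradients; the function name clip_by_global_norm is hypothetical and the code is not taken from DeepSpeed.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale all gradients by one factor so their joint L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return grads, total_norm


# Gradients with L2 norm 5.0, clipped to a max norm of 1.0.
clipped, norm = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
```

Because every gradient is scaled by the same factor, the direction of the update is preserved and only its length changes.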

Automatic loss scaling with mixed precision

DeepSpeed internally handles loss scaling for mixed precision training. The parameters for loss scaling can be specified in the deepspeed_config JSON file. Please see the core API doc for more details.
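
To give a feel for what the loss scaling parameters control, here is a simplified sketch of a dynamic loss scaler. The constructor arguments scale_window, hysteresis, and min_scale are loosely named after the loss_scale_window, hysteresis, and min_loss_scale keys of the fp16 config; that mapping, the init_scale and factor arguments, and the exact update rule are assumptions for illustration and may differ from what DeepSpeed does internally.

```python
class DynamicLossScaler:
    """Simplified dynamic loss scaler (illustrative, not DeepSpeed's code)."""

    def __init__(self, init_scale=2.0 ** 16, scale_window=1000,
                 hysteresis=2, min_scale=1.0, factor=2.0):
        self.scale = init_scale
        self.scale_window = scale_window
        self.hysteresis = hysteresis
        self.min_scale = min_scale
        self.factor = factor
        self.overflows_left = hysteresis  # overflows tolerated before shrinking
        self.good_steps = 0               # consecutive steps without overflow

    def update(self, overflow):
        """Record one step; return True if the optimizer step should be applied."""
        if overflow:
            self.good_steps = 0
            self.overflows_left -= 1
            if self.overflows_left <= 0:
                # Allowance used up: shrink the scale, but not below the floor.
                self.scale = max(self.scale / self.factor, self.min_scale)
            return False  # skip the weight update on overflow
        self.good_steps += 1
        if self.good_steps % self.scale_window == 0:
            # A full clean window: grow the scale and restore the allowance.
            self.scale *= self.factor
            self.overflows_left = self.hysteresis
        return True


scaler = DynamicLossScaler(init_scale=1024.0, scale_window=3, hysteresis=2)
history = []
for overflow in [True, True, False, False, False]:
    scaler.update(overflow)
    history.append(scaler.scale)
```

In this run the first overflow is tolerated, the second halves the scale, and three clean steps in a row double it again.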

Training Optimizers

Fused Adam optimizer and arbitrary torch.optim.Optimizer

With DeepSpeed, the user can choose to use a high-performance implementation of Adam from NVIDIA, or any training optimizer that extends PyTorch's torch.optim.Optimizer class.

Memory bandwidth optimized FP16 Optimizer

Mixed precision training is handled by the DeepSpeed FP16 Optimizer, which is designed for efficiency as well as FP16 support. The performance of the weight update is dominated by memory bandwidth, and the achieved memory bandwidth depends on the size of the input operands. The FP16 Optimizer maximizes the achievable memory bandwidth by merging all the parameters of the model into a single large buffer and applying the weight updates in a single kernel.

Large Batch Training with LAMB Optimizer

DeepSpeed makes it easy to train with large batch sizes by enabling the LAMB Optimizer. For more details on LAMB, see the LAMB paper.

Memory-Efficient Training with ZeRO Optimizer

DeepSpeed can train models with up to 6 billion parameters without model parallelism, and models with up to 100 billion parameters with 16-way model parallelism. This leap in model size is possible through the memory efficiency achieved via the ZeRO Optimizer. For more details, see the ZeRO paper.

Training Agnostic Checkpointing

DeepSpeed can simplify checkpointing for you regardless of whether you are using data parallel training, model parallel training, mixed-precision training, a mix of these three, or the ZeRO optimizer to enable larger model sizes. Please see the Getting Started guide and the core API doc for more details.

Advanced parameter search

DeepSpeed supports multiple Learning Rate Schedules to enable faster convergence for large batch scaling.

Learning Rate Range Test

Please refer to the Learning Rate Range Test tutorial.

1Cycle Learning Rate Schedule

Please refer to the 1Cycle Learning Rate Schedule tutorial.

Simplified Data Loader

DeepSpeed abstracts away data parallelism and model parallelism from the user when it comes to data loading. Users simply provide a PyTorch dataset, and the DeepSpeed data loader automatically handles batch creation.
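
One common way a data-parallel loader divides work is sketched below. The round-robin split and the helper names shard_indices and make_batches are assumptions for illustration; the source does not say how the DeepSpeed data loader partitions a dataset internally.

```python
def shard_indices(dataset_len, dp_rank, dp_world_size):
    """Sample indices handled by one data-parallel rank (round-robin split)."""
    return list(range(dp_rank, dataset_len, dp_world_size))


def make_batches(indices, micro_batch_size):
    """Cut one rank's indices into consecutive micro-batches."""
    return [
        indices[i:i + micro_batch_size]
        for i in range(0, len(indices), micro_batch_size)
    ]


# 10 samples split over 4 data-parallel ranks, viewed from rank 1.
rank1_indices = shard_indices(dataset_len=10, dp_rank=1, dp_world_size=4)
rank1_batches = make_batches(rank1_indices, micro_batch_size=2)
```

Ranks in the same model-parallel group would use the same data-parallel rank, so they all see the same samples.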

Performance Analysis and Debugging

For performance debugging, DeepSpeed can give you a detailed breakdown of the time spent in different parts of training; simply enable it in the deepspeed_config file. Please see the core API doc for more details.

{
  "wall_clock_breakdown": true
}