Mixed Precision Training
Key Points
- Mixed precision training stores and computes tensors in low precision for speed and memory savings while keeping a master copy of weights in FP32 for accurate updates.
- Forward and backward passes run mostly in FP16/BF16 with FP32 accumulation, but the optimizer updates FP32 master weights to avoid numerical drift.
- Dynamic loss scaling multiplies the loss by a factor s to prevent gradient underflow in low precision, then divides gradients by s before the FP32 update.
- BF16 is often preferred over FP16 because it shares FP32’s 8-bit exponent, reducing overflow/underflow while still saving memory.
- The critical stability trick is “compute in wider type, store in narrower type”: accumulate dot products in FP32 even if inputs are FP16/BF16.
- Using FP16/BF16 can nearly halve memory traffic and often increases throughput on modern GPUs with Tensor Cores.
- You must detect inf/NaN in scaled gradients to adjust the loss scale, skipping updates on overflow and reducing s.
- Mixed precision pays off in large models and bandwidth-bound workloads; it may not help much on very small models or on hardware without fast low-precision units.
Prerequisites
- Floating-point formats (IEEE 754) — Understanding exponent/mantissa, rounding, and range clarifies why FP16/BF16 behave differently from FP32.
- Matrix multiplication (GEMM) — Most training time is spent in GEMM/convolutions; knowing accumulation error explains the need for FP32 accumulators.
- Backpropagation and gradients — You must know how gradients are computed to see where scaling and precision choices apply.
- Stochastic gradient descent and optimizers — Mixed precision hinges on applying updates to FP32 master weights with correctly unscaled gradients.
- Numerical stability and overflow/underflow — Loss scaling and FP32 accumulation are countermeasures to these issues in low precision.
Detailed Explanation
01 Overview
Hook: Imagine carrying groceries in smaller bags so you can move faster, but keeping the receipt in a safe place so the numbers remain correct. Mixed precision training does this with numbers: small, fast data for speed, and a safe, accurate copy for correctness.

Concept: Mixed precision training uses low-precision formats like FP16 or BF16 during forward and backward passes to reduce memory use and increase throughput, while maintaining an FP32 “master” copy of weights for numerically stable optimization. The hardware executes many more low-precision operations per second, and memory traffic per tensor is roughly halved. Meanwhile, the FP32 master parameters and FP32 accumulations preserve training stability and final model quality.

Example: A linear layer multiplies an activation matrix by a weight matrix. In mixed precision, activations and weights are stored as FP16/BF16. The GPU multiplies them (often accumulating in FP32), producing outputs. Gradients are computed similarly. Before the optimizer step, gradients are unscaled and applied to the FP32 master weights. The low-precision weights are then refreshed from the master copy for the next iteration.
02 Intuition & Analogies
Hook: Think of doing rough arithmetic on a scratchpad to move quickly, but keeping an exact ledger for important totals. You make fast approximations while working, then consult your exact ledger when committing results.

Concept: FP16/BF16 are like the scratchpad: smaller, faster, and “good enough” most of the time. FP32 is the ledger: accurate and used when correctness matters most (the weight update). During matrix multiplications, we still add numbers using FP32 (like writing neat totals in your ledger), even though individual values are stored in a compact format. This limits the error that accumulates when summing many terms.

Example: Suppose you’re summing many small prices. If you round after every addition, small cents might vanish. But if you keep cents during the sum and only round once at the end, totals are much more accurate. Mixed precision mirrors this: store FP16/BF16 to save space, but accumulate products in FP32 and only round when storing back. Loss scaling is like temporarily changing units (e.g., measuring in cents instead of dollars) so tiny gradients don’t vanish; you convert back before updating the ledger.
03 Formal Definition
04 When to Use
Hook: Use mixed precision when speed and memory are bottlenecks but you still need accurate learning dynamics.

Concept: It’s ideal for deep networks where most runtime is in matrix multiplications or convolutions, and the hardware supports fast FP16/BF16 (e.g., NVIDIA Tensor Cores, TPUs, modern CPUs with AMX).

Use cases:
- Large transformer or CNN models where activations dominate memory; FP16/BF16 nearly halves activation memory, enabling larger batch sizes or models.
- Training regimes that are bandwidth-bound; 16-bit storage halves traffic and often saturates compute units better.
- Inference at low precision after training at mixed precision to reduce latency and memory footprint.
- BF16 preferred when gradients have a wide dynamic range (since BF16 shares FP32’s exponent), or on hardware with native BF16 support.

Not ideal when:
- Models are tiny (kernel launch or framework overhead dominates) or hardware lacks low-precision acceleration.
- Numerically brittle algorithms (e.g., some RNNs without normalization) where even BF16 may be risky without extra care.
⚠️ Common Mistakes
Hook: Most failures come from mixing the wrong precisions in the wrong places or forgetting guardrails.

Concept and pitfalls:
- Updating weights in FP16/BF16. Always keep an FP32 master copy for optimizer math; otherwise, rounding error compounds and training can diverge.
- No FP32 accumulation. Dot products summed in FP16/BF16 lose significant bits; insist on FP32 (or wider) accumulators for GEMMs and convolutions.
- Skipping loss scaling with FP16. FP16’s narrow exponent often underflows small gradients to zero. Use dynamic loss scaling or switch to BF16 when available.
- Failing to check for inf/NaN after scaling. If you don’t detect overflow and skip the step, you may inject NaNs into the FP32 master state.
- Inconsistent casting. Forgetting to refresh low-precision weights from the FP32 master after each step can desynchronize forward/backward from optimizer state.
- Casting optimizer states (Adam’s m, v) to FP16/BF16. Keep them in FP32; otherwise rounding in the moment estimates negates Adam’s benefits.

Example fixes: Maintain FP32 master weights and optimizer states, use FP32 accumulation, integrate dynamic loss scaling with overflow checks, and refresh BF16/FP16 copies each iteration.
Key Formulas
FP32 Master Update with Loss Scaling
w32 ← w32 − η · (g_scaled / s), where g_scaled = ∇w(s · L)
Explanation: The optimizer applies gradients to FP32 master weights. Gradients are produced from a scaled loss and must be divided by the current scaling factor s before the update.
Overflow-Responsive Scaling
If any scaled gradient is inf/NaN: skip the update and set s ← s / 2; otherwise leave s unchanged
Explanation: If any gradient becomes inf/NaN, we skip the step and reduce the scaling factor to prevent future overflows. Without overflow, we keep s unchanged (or increase it on a schedule).
Safe-Step Doubling Rule
After K consecutive overflow-free steps: s ← 2s (and reset the safe-step counter)
Explanation: After K safe steps, we can increase the loss scale to reduce underflow risk and possibly improve the signal-to-quantization ratio.
Matrix Multiply (GEMM)
C[i, j] = Σ_{k=1..K} A[i, k] · B[k, j]
Explanation: Matrix multiplication sums K products per output element. Using FP32 accumulation for this sum greatly improves numerical accuracy when inputs are FP16/BF16.
Loss Scaling Identity
∇w(s · L) = s · ∇w L, so (1/s) · ∇w(s · L) = ∇w L
Explanation: Scaling the loss by s multiplies gradients by s during backprop. Dividing by s before the update restores the original gradient magnitude if no overflow occurred.
Memory Saving from 16-bit Storage
bytes(FP16/BF16 tensor) = bytes(FP32 tensor) / 2
Explanation: Storing tensors in 16-bit formats roughly halves memory. Optimizer states kept in FP32 limit total savings, but activations and gradients still benefit heavily.
BF16 Rounding (to-nearest-even)
bits16 = (bits32 + 0x7FFF + lsb) >> 16, where lsb = (bits32 >> 16) & 1
Explanation: Converting float32 to bfloat16 can be implemented by rounding away the lower 16 mantissa bits, then reinterpreting the top 16 bits as a float. The sign and exponent are preserved (apart from a possible carry when rounding up at a binade boundary).
Empirical Risk
L(w) = (1/N) Σ_{i=1..N} ℓ(f(x_i; w), y_i)
Explanation: The training loss is the average over per-sample losses. Mixed precision does not change the objective, only the numeric format during computation.
Complexity Analysis
Mixed precision does not change asymptotic complexity: a GEMM still costs O(N·K·M) operations. It changes the constants — 16-bit storage halves memory traffic, and low-precision units raise peak throughput — so wall-clock time and memory use improve while the operation count stays the same.
Code Examples
#include <iostream>
#include <vector>
#include <random>
#include <cstdint>
#include <cstring>
#include <cmath>
#include <limits>
#include <algorithm>

// Minimal bfloat16 helper: store as 16-bit, convert to/from float with round-to-nearest-even
struct bf16 {
    uint16_t bits;
    bf16() : bits(0) {}
    explicit bf16(float x) { bits = float_to_bf16_bits(x); }
    static uint16_t float_to_bf16_bits(float x) {
        // Round-to-nearest-even by adding 0x7FFF and handling ties via the dropped LSB
        uint32_t u;
        static_assert(sizeof(u) == sizeof(x), "sizes must match");
        std::memcpy(&u, &x, sizeof(u));
        uint32_t lsb = (u >> 16) & 1u;               // least-significant bit of the surviving mantissa
        uint32_t rounding_bias = 0x8000u - 1u + lsb; // 0x7FFF + lsb ensures ties-to-even
        u += rounding_bias;
        return static_cast<uint16_t>(u >> 16);
    }
    static float bf16_bits_to_float(uint16_t b) {
        uint32_t u = static_cast<uint32_t>(b) << 16;
        float x;
        std::memcpy(&x, &u, sizeof(x));
        return x;
    }
    float to_float() const { return bf16_bits_to_float(bits); }
};

// Utility: check overflow (inf/NaN) in a float vector
bool has_overflow(const std::vector<float>& v) {
    for (float x : v) {
        if (!std::isfinite(x)) return true;
    }
    return false;
}

// A simple mixed-precision linear model: y = x^T w + b
struct MixedPrecisionLinear {
    int D;                      // input dimension
    std::vector<float> w32;     // FP32 master weights
    float b32;                  // FP32 master bias
    std::vector<uint16_t> w16;  // BF16 forward copy (bits)

    explicit MixedPrecisionLinear(int D_) : D(D_), w32(D_, 0.0f), b32(0.0f), w16(D_, 0) {}

    void refresh_low_precision_copy() {
        for (int i = 0; i < D; ++i) w16[i] = bf16::float_to_bf16_bits(w32[i]);
    }

    // Forward pass: x comes as bf16 bits; multiply in float, accumulate in float
    float forward(const std::vector<uint16_t>& x_bf16_bits) const {
        float sum = 0.0f;
        for (int i = 0; i < D; ++i) {
            float xi = bf16::bf16_bits_to_float(x_bf16_bits[i]);
            float wi = bf16::bf16_bits_to_float(w16[i]);
            sum += xi * wi; // accumulate in FP32 for stability
        }
        return sum + b32;
    }

    // Accumulate gradients for one sample given dL/dy (grad_out) into grad_w_accum / grad_b_accum
    void grad_one(const std::vector<uint16_t>& x_bf16_bits, float grad_out,
                  std::vector<float>& grad_w_accum, float& grad_b_accum) const {
        for (int i = 0; i < D; ++i) {
            float xi = bf16::bf16_bits_to_float(x_bf16_bits[i]);
            grad_w_accum[i] += grad_out * xi; // dL/dw_i = dL/dy * x_i
        }
        grad_b_accum += grad_out; // dL/db = dL/dy
    }

    // SGD update on FP32 master weights
    void sgd_step(const std::vector<float>& grad_w, float grad_b, float lr) {
        for (int i = 0; i < D; ++i) w32[i] -= lr * grad_w[i];
        b32 -= lr * grad_b;
    }
};

int main() {
    // Problem setup: synthetic data for linear regression
    const int N = 1024;     // samples
    const int D = 16;       // features
    const float lr = 1e-2f;
    const int epochs = 200;

    std::mt19937 rng(42);
    std::normal_distribution<float> nd(0.0f, 1.0f);

    // True weights for data generation (float32)
    std::vector<float> w_true(D);
    for (int i = 0; i < D; ++i) w_true[i] = nd(rng);
    float b_true = nd(rng);

    // Dataset stored in BF16 for the forward pass
    std::vector<std::vector<uint16_t>> X(N, std::vector<uint16_t>(D));
    std::vector<float> y(N);
    for (int i = 0; i < N; ++i) {
        float sum = 0.0f;
        for (int j = 0; j < D; ++j) {
            float xij = nd(rng);
            X[i][j] = bf16::float_to_bf16_bits(xij);
            sum += xij * w_true[j];
        }
        y[i] = sum + b_true + 0.01f * nd(rng); // small noise
    }

    // Model
    MixedPrecisionLinear model(D);
    // Initialize master weights small
    std::uniform_real_distribution<float> ud(-0.1f, 0.1f);
    for (int i = 0; i < D; ++i) model.w32[i] = ud(rng);
    model.b32 = 0.0f;
    model.refresh_low_precision_copy();

    // Dynamic loss scaling
    float loss_scale = 128.0f;     // start moderately high
    int safe_steps = 0;
    const int safe_to_double = 50; // after this many safe steps, try increasing the scale

    for (int epoch = 1; epoch <= epochs; ++epoch) {
        // Forward + backward over the whole dataset (full-batch for simplicity)
        std::vector<float> grad_w_scaled(D, 0.0f);
        float grad_b_scaled = 0.0f;
        float mse = 0.0f;

        for (int i = 0; i < N; ++i) {
            float y_pred = model.forward(X[i]); // FP32 accumulation
            float err = y_pred - y[i];
            mse += err * err;
            // dL/dy for MSE with mean reduction: (2/N) * (y_pred - y)
            float dldy = (2.0f / N) * err;
            // Apply loss scaling: scale gradients as if the loss were scaled
            float dldy_scaled = dldy * loss_scale;
            model.grad_one(X[i], dldy_scaled, grad_w_scaled, grad_b_scaled);
        }
        mse /= N;

        // Overflow check on scaled gradients
        bool overflow = has_overflow(grad_w_scaled) || !std::isfinite(grad_b_scaled);
        if (overflow) {
            loss_scale = std::max(1.0f, loss_scale / 2.0f);
            safe_steps = 0;
            std::cout << "Epoch " << epoch << ": overflow detected. Reducing loss_scale to "
                      << loss_scale << " and skipping update. MSE=" << mse << "\n";
            // Skip the update; weights are unchanged, so the BF16 copy remains valid
            model.refresh_low_precision_copy();
            continue;
        }

        // Unscale gradients before the optimizer step
        std::vector<float> grad_w(D, 0.0f);
        for (int i = 0; i < D; ++i) grad_w[i] = grad_w_scaled[i] / loss_scale;
        float grad_b = grad_b_scaled / loss_scale;

        // FP32 master update
        model.sgd_step(grad_w, grad_b, lr);
        // Refresh the BF16 copy for the next forward/backward
        model.refresh_low_precision_copy();

        // Increase the dynamic loss scale cautiously after many safe steps
        safe_steps += 1;
        if (safe_steps >= safe_to_double) {
            loss_scale *= 2.0f;
            safe_steps = 0;
        }

        if (epoch % 20 == 0 || epoch == 1) {
            std::cout << "Epoch " << epoch << ": MSE=" << mse << ", loss_scale=" << loss_scale << "\n";
        }
    }

    // Report final error vs. ground truth
    float max_abs_err = 0.0f;
    for (int i = 0; i < model.D; ++i) {
        float e = std::abs(model.w32[i] - w_true[i]);
        if (e > max_abs_err) max_abs_err = e;
    }
    std::cout << "Max abs weight error vs truth: " << max_abs_err << "\n";
    return 0;
}
This self-contained example demonstrates mixed precision training with BF16 for the forward/backward paths and FP32 master weights for updates. We define a minimal bfloat16 representation with round-to-nearest-even conversion. The dataset inputs are stored in BF16. The model keeps FP32 master weights and a BF16 copy for forward. Gradients are computed from a scaled loss (dynamic loss scaling). If scaled gradients overflow (inf/NaN), the step is skipped and the scale is reduced. Otherwise, gradients are unscaled and applied to the FP32 master weights via SGD, then the BF16 copy is refreshed. This mirrors practical mixed precision workflows in frameworks while remaining pure C++.
#include <iostream>
#include <vector>
#include <random>
#include <cstdint>
#include <cstring>
#include <cmath>
#include <algorithm>

struct bf16 {
    uint16_t bits;
    bf16() : bits(0) {}
    explicit bf16(float x) { bits = float_to_bf16_bits(x); }
    static uint16_t float_to_bf16_bits(float x) {
        uint32_t u; std::memcpy(&u, &x, sizeof(u));
        uint32_t lsb = (u >> 16) & 1u;
        uint32_t rounding_bias = 0x8000u - 1u + lsb; // ties-to-even
        u += rounding_bias;
        return static_cast<uint16_t>(u >> 16);
    }
    static float bf16_bits_to_float(uint16_t b) {
        uint32_t u = static_cast<uint32_t>(b) << 16;
        float x; std::memcpy(&x, &u, sizeof(x));
        return x;
    }
};

// C = A (N x K) * B (K x M); variants: accumulate in BF16 (quantize after every add) vs FP32
void gemm_bf16_accum_bf16(const std::vector<uint16_t>& A, const std::vector<uint16_t>& B,
                          std::vector<uint16_t>& C, int N, int K, int M) {
    for (int i = 0; i < N; ++i) {
        for (int j = 0; j < M; ++j) {
            uint16_t sum_b = bf16::float_to_bf16_bits(0.0f);
            for (int k = 0; k < K; ++k) {
                float a = bf16::bf16_bits_to_float(A[i*K + k]);
                float b = bf16::bf16_bits_to_float(B[k*M + j]);
                float prod = a * b;
                // Accumulate in BF16 by re-quantizing after each addition
                float sum_f = bf16::bf16_bits_to_float(sum_b) + prod;
                sum_b = bf16::float_to_bf16_bits(sum_f);
            }
            C[i*M + j] = sum_b; // store BF16 result
        }
    }
}

void gemm_bf16_accum_fp32(const std::vector<uint16_t>& A, const std::vector<uint16_t>& B,
                          std::vector<uint16_t>& C, int N, int K, int M) {
    for (int i = 0; i < N; ++i) {
        for (int j = 0; j < M; ++j) {
            float sum = 0.0f; // FP32 accumulator
            for (int k = 0; k < K; ++k) {
                float a = bf16::bf16_bits_to_float(A[i*K + k]);
                float b = bf16::bf16_bits_to_float(B[k*M + j]);
                sum += a * b;
            }
            C[i*M + j] = bf16::float_to_bf16_bits(sum); // quantize once at the end
        }
    }
}

void gemm_float32_ref(const std::vector<float>& A, const std::vector<float>& B,
                      std::vector<float>& C, int N, int K, int M) {
    for (int i = 0; i < N; ++i) {
        for (int j = 0; j < M; ++j) {
            float sum = 0.0f;
            for (int k = 0; k < K; ++k) sum += A[i*K + k] * B[k*M + j];
            C[i*M + j] = sum;
        }
    }
}

float relative_error(const std::vector<float>& ref, const std::vector<uint16_t>& test_bf16) {
    double num = 0.0, den = 0.0;
    for (size_t i = 0; i < ref.size(); ++i) {
        float test = bf16::bf16_bits_to_float(test_bf16[i]);
        double r = ref[i];
        double diff = std::abs(test - r);
        num += diff;
        den += std::max(1e-12, std::abs(r));
    }
    return static_cast<float>(num / den);
}

int main() {
    int N = 64, K = 128, M = 64;
    std::mt19937 rng(0);
    std::normal_distribution<float> nd(0.0f, 1.0f);

    std::vector<float> Af(N*K), Bf(K*M), Cref(N*M);
    for (auto &x : Af) x = nd(rng) * 0.1f; // scale inputs to avoid trivial overflow
    for (auto &x : Bf) x = nd(rng) * 0.1f;

    // Convert to BF16 storage
    std::vector<uint16_t> A(N*K), B(K*M), C1(N*M), C2(N*M);
    for (int i = 0; i < N*K; ++i) A[i] = bf16::float_to_bf16_bits(Af[i]);
    for (int i = 0; i < K*M; ++i) B[i] = bf16::float_to_bf16_bits(Bf[i]);

    gemm_float32_ref(Af, Bf, Cref, N, K, M);
    gemm_bf16_accum_bf16(A, B, C1, N, K, M);
    gemm_bf16_accum_fp32(A, B, C2, N, K, M);

    float err_bf16acc = relative_error(Cref, C1);
    float err_fp32acc = relative_error(Cref, C2);

    std::cout << "Relative L1 error vs FP32 ref:\n";
    std::cout << "  Accumulate in BF16: " << err_bf16acc << "\n";
    std::cout << "  Accumulate in FP32: " << err_fp32acc << "\n";
    return 0;
}
This program simulates two ways to multiply BF16 matrices: one that re-quantizes the running sum to BF16 after each addition (poor accuracy) and one that accumulates in FP32 and only quantizes once at the end (good accuracy). It reports relative error versus a pure FP32 reference GEMM. The result highlights why mixed precision uses low-precision inputs but FP32 accumulation during forward/backward passes.