1. Memory Accounting
Almost everything (parameters, gradients, activations, optimizer states) is stored as floating-point numbers.
Types of floating-point numbers:
- float32 (fp32, single precision)
- float16 (fp16, half precision)
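- bfloat16: uses the same memory as fp16 (16 bits), but sacrifices precision (fewer fraction bits: 7 vs. fp16's 10) to gain dynamic range (more exponent bits: 8, the same as float32, vs. fp16's 5)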
- fp8
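A quick way to see these tradeoffs is to ask PyTorch for the per-element size and the range/precision limits of each type. This is a small sketch using `torch.finfo` and `Tensor.element_size` (both standard PyTorch); the helper `tensor_bytes` is a name introduced here for illustration, and fp8 is left out because the available fp8 dtypes depend on the PyTorch version.

```python
import torch


def tensor_bytes(x: torch.Tensor) -> int:
    """Memory used by a tensor's elements, in bytes (illustrative helper)."""
    return x.numel() * x.element_size()


for dtype in (torch.float32, torch.float16, torch.bfloat16):
    x = torch.zeros(4, 8, dtype=dtype)
    info = torch.finfo(dtype)
    # bytes for 32 elements, largest finite value, spacing of numbers near 1.0
    print(f"{str(dtype):15} bytes={tensor_bytes(x):3d} max={info.max:.3e} eps={info.eps:.3e}")

# fp16 has few exponent bits, so tiny values underflow to zero; bf16 keeps them.
print(torch.tensor([1e-8], dtype=torch.float16))   # underflows to 0
print(torch.tensor([1e-8], dtype=torch.bfloat16))  # still nonzero
```

bf16 reports a much larger max than fp16 but also a larger eps, which is exactly the range-for-precision trade described above.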
Intuition: When to use float32 and bf16?
- float32: mainly for parameters and optimizer states, which accumulate updates over time and need higher precision.
- bf16: for transitory values. Take the parameters, cast them to bf16, and run the forward computation in bf16.
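The "accumulate over time" point can be seen directly. In this sketch (the update size 1e-3 and the 1000 steps are arbitrary choices), a running value receives many small updates, as a parameter does during training; in bf16 each update is smaller than half the spacing between representable numbers near 1.0, so it is rounded away.

```python
import torch

# Add a small update (1e-3) to a running value 1000 times, as an optimizer does.
acc_bf16 = torch.tensor(1.0, dtype=torch.bfloat16)
acc_fp32 = torch.tensor(1.0, dtype=torch.float32)
for _ in range(1000):
    acc_bf16 += 1e-3  # near 1.0 bf16 spacing is 2**-7, so 1.001 rounds back to 1.0
    acc_fp32 += 1e-3

print(acc_bf16.item())  # stays at 1.0: every update was lost
print(acc_fp32.item())  # close to 2.0
```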
2. Compute Accounting
2.1 Tensor on GPU
By default, tensors are stored in CPU memory. To place a tensor in GPU memory, move it with .to("cuda") or create it with device="cuda".
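Both ways of getting a tensor onto the GPU, in a minimal sketch. It assumes you may not have a GPU at hand, so it falls back to the CPU to stay runnable.

```python
import torch

# Fall back to the CPU so the sketch still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

x = torch.zeros(32, 32)                 # created in CPU memory by default
y = x.to(device)                        # copy to the chosen device
z = torch.zeros(32, 32, device=device)  # allocate directly on the device

print(x.device, y.device, z.device)
if device == "cuda":
    # bytes currently allocated for tensors on the GPU (includes y and z)
    print(torch.cuda.memory_allocated())
```

Creating the tensor directly on the device avoids an extra host-to-device copy.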
2.2 Tensor Operations
Tensor Storage
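PyTorch tensors are really pointers into allocated memory, plus metadata describing how to get to any element of the tensor.
stride(i): the number of elements to skip in the underlying storage to move to the next element along dimension i.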
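A small sketch of how strides map an index to a storage position. The 4x4 example tensor and the chosen index (1, 2) are arbitrary.

```python
import torch

x = torch.arange(16.0).view(4, 4)  # storage holds 0., 1., ..., 15. in row-major order
print(x.stride())  # (4, 1): skip 4 elements to move down a row, 1 to move right

# Locate element (r, c) in the underlying storage by hand.
r, c = 1, 2
index = r * x.stride(0) + c * x.stride(1)  # 1*4 + 2*1 = 6
assert x.view(-1)[index] == x[r, c]        # both are 6.
```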
Tensor Slicing
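Many operations simply provide a different view of the tensor. No copy is made underneath, so mutating one tensor also affects every other tensor that shares its storage.
For example:
- Get a row / column by indexing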
- View a 2x3 matrix as a 3x2 matrix:
y = x.view(3, 2)
- Transpose:
y = x.transpose(1, 0)
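The views listed above can be checked to share memory. This sketch compares storage pointers with `untyped_storage().data_ptr()` (available in PyTorch 2.x) and then mutates the original tensor.

```python
import torch

x = torch.tensor([[1., 2, 3], [4, 5, 6]])
row = x[0]             # row view
col = x[:, 1]          # column view
v = x.view(3, 2)       # reshaped view
t = x.transpose(1, 0)  # transposed view

# All four share x's storage: no data was copied.
base = x.untyped_storage().data_ptr()
for view in (row, col, v, t):
    assert view.untyped_storage().data_ptr() == base

# Mutating x is visible through every view.
x[0, 1] = 100.0
print(row[1], col[0], v[0, 1], t[1, 0])  # all equal to 100.
```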
Note that some views are non-contiguous, which means that further views may not be possible:
import torch

x = torch.tensor([[1., 2, 3], [4, 5, 6]])
y = x.transpose(1, 0)
assert not y.is_contiguous()
try:
    y.view(2, 3)
    assert False, "view() should fail on a non-contiguous tensor"
except RuntimeError as e:
    assert "view size is not compatible with input tensor's size and stride" in str(e)
But you can first force the tensor to be contiguous:
y = x.transpose(1, 0).contiguous().view(2, 3)
Note that this copies the underlying data.
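A sketch confirming that the .contiguous() call above produces an independent copy. The last line uses `reshape`, an alternative not mentioned in these notes: it returns a view when the strides allow and copies otherwise.

```python
import torch

x = torch.tensor([[1., 2, 3], [4, 5, 6]])
y = x.transpose(1, 0).contiguous().view(2, 3)
print(y)  # values [[1, 4, 2], [5, 3, 6]]: the transposed data, laid out fresh

# y owns a new buffer, so it no longer tracks changes to x.
assert y.untyped_storage().data_ptr() != x.untyped_storage().data_ptr()
x[0, 0] = 100.0
assert y[0, 0] == 1.0

# reshape() views when possible and silently copies otherwise.
z = x.transpose(1, 0).reshape(2, 3)
```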
Tensor Elementwise
These operations apply some operation to each element of the tensor and return a (new) tensor of the same shape.
triu takes the upper triangular part of a matrix (the entries on and above the diagonal).
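A few elementwise operations and triu, as a small sketch with arbitrary inputs. Every result has the same shape as its input.

```python
import torch

x = torch.tensor([1., 4, 9])
print(x.pow(2))   # [1., 16., 81.]
print(x.sqrt())   # [1., 2., 3.]
print(x.rsqrt())  # 1 / sqrt(x): [1., 0.5, 0.3333]
print(x + x, x * 2, x / 0.5)  # all [2., 8., 18.]

# triu keeps entries on and above the diagonal and zeroes the rest.
m = torch.ones(3, 3).triu()
print(m)  # [[1., 1., 1.], [0., 1., 1.], [0., 0., 1.]]
```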
Tensor Matmul
Matrix multiplication is the bread and butter of deep learning (bread and butter: the core, most basic thing one relies on).
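A minimal matmul sketch with arbitrary shapes. It also shows that extra leading dimensions (such as batch and sequence) are handled automatically: the same weight matrix multiplies every trailing 2-D slice.

```python
import torch

x = torch.ones(16, 32)
w = torch.ones(32, 2)
y = x @ w
print(y.shape)  # torch.Size([16, 2]); each entry sums 32 products of 1 * 1 = 32.

# Leading dimensions are broadcast: w multiplies every 16x32 slice.
xb = torch.ones(4, 8, 16, 32)
yb = xb @ w
print(yb.shape)  # torch.Size([4, 8, 16, 2])
```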
2.3 Tensor Einops
Einops is a library for manipulating tensors where dimensions are named.
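einops is a separate package, so to keep this sketch dependency-free it uses PyTorch's built-in `torch.einsum`, which follows the same idea with single-letter dimension names. The einops version is shown only as a comment (recent einops releases provide `einsum` taking full dimension names); the tensor sizes are arbitrary.

```python
import torch

batch, seq1, seq2, hidden = 2, 3, 3, 4
x = torch.ones(batch, seq1, hidden)
y = torch.ones(batch, seq2, hidden)

# Old style: the meaning of -2 and -1 must be tracked by hand.
z_old = x @ y.transpose(-2, -1)

# Named style: b = batch, i = seq1, j = seq2, h = hidden; h is summed out.
# einops equivalent: einsum(x, y, "batch seq1 hidden, batch seq2 hidden -> batch seq1 seq2")
z = torch.einsum("bih,bjh->bij", x, y)
assert torch.equal(z, z_old)
print(z.shape)  # torch.Size([2, 3, 3])
```

Naming the dimensions makes it explicit which axis is contracted, instead of relying on negative indices.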
Jaxtyping Basics
jaxtyping is a library providing type annotations and runtime type-checking for the shapes and dtypes of arrays/tensors (JAX, PyTorch, NumPy).
- TODO: combining einops and jaxtyping in practice
2.4 Tensor Operations FLOPs
A FLOP (floating-point operation) is a basic operation like addition or multiplication.
Multiplying a B × D matrix by a D × K matrix takes 2 × B × D × K FLOPs: every multiply is followed by an add.
Interpretation:
- B is the number of data points
- D × K is the number of parameters
- FLOPs for the forward pass is 2 (# data points) (# parameters)
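The 2 × B × D × K count can be checked by brute force: a plain-Python triple loop does one multiply and one add per (i, d, k) triple. The helpers `matmul_flops` and `naive_matmul` are illustrative names, and the sizes in the last line are arbitrary.

```python
def matmul_flops(B: int, D: int, K: int) -> int:
    """FLOPs for (B x D) @ (D x K): one multiply and one add per (i, d, k) triple."""
    return 2 * B * D * K


def naive_matmul(x, w):
    """Triple-loop matmul on nested lists that also counts its FLOPs."""
    B, D, K = len(x), len(w), len(w[0])
    out = [[0.0] * K for _ in range(B)]
    flops = 0
    for i in range(B):
        for k in range(K):
            for d in range(D):
                out[i][k] += x[i][d] * w[d][k]  # 1 multiply + 1 add
                flops += 2
    return out, flops


x = [[1.0] * 5 for _ in range(3)]  # B=3, D=5
w = [[1.0] * 4 for _ in range(5)]  # D=5, K=4
out, counted = naive_matmul(x, w)
print(counted, matmul_flops(3, 5, 4))  # 120 120

# Example scale: B=16384, D=32768, K=8192.
print(matmul_flops(16384, 32768, 8192))  # 8796093022208 (= 2**43)
```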
Model FLOPs utilization (MFU)
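Definition: MFU = (actual FLOP/s) / (promised FLOP/s), where actual FLOP/s counts only the model's useful FLOPs (ignoring communication and other overhead) and promised FLOP/s is the hardware's advertised peak.
Usually, an MFU of >= 0.5 is quite good (and it will be higher if matmuls dominate).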
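The definition as arithmetic, in a sketch with hypothetical numbers (the runtime and the 1e15 FLOP/s peak are made up, not a specific GPU). In practice the elapsed time would come from wall-clock timing around the workload, with `torch.cuda.synchronize()` before reading the clock on a GPU.

```python
def mfu(model_flops: float, seconds: float, promised_flop_per_sec: float) -> float:
    """Model FLOPs utilization: achieved useful FLOP/s over the hardware's peak."""
    actual_flop_per_sec = model_flops / seconds
    return actual_flop_per_sec / promised_flop_per_sec


# Hypothetical numbers: a 2**43-FLOP matmul measured at 20 ms on a device
# advertised at 1e15 FLOP/s.
print(mfu(2**43, 0.02, 1e15))  # ≈ 0.44
```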
2.5 Gradients
Consider a weight w that connects an input unit i to an output unit j . For each example in the batch, the weight w generates exactly 6 FLOPs combined in the forward and backward pass:
- The unit i multiplies its output h(i) by w to send it to the unit j.
- The unit j adds the unit i’s contribution to its total input a(j).
- The unit j multiplies the incoming loss gradient dL/da(j) by w to send it back to the unit i.
- The unit i adds the unit j’s contribution to its total loss gradient dL/dh(i).
- The unit j multiplies its loss gradient dL/da(j) by the unit i’s output h(i) to compute the loss gradient dL/dw for the given example.
- (The sneakiest FLOP, IMHO) The weight w adds the contribution from the previous step to its loss gradient accumulator dL/dw, which aggregates gradients over all examples.
Putting it together:
- Forward pass: 2 (# data points) (# parameters) FLOPs
- Backward pass: 4 (# data points) (# parameters) FLOPs
- Total: 6 (# data points) (# parameters) FLOPs
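The same 2 + 4 split can be seen at the matrix level. In this sketch of a single linear layer y = x @ w, the backward pass needs two matmuls (one for dL/dw, one for dL/dx), each costing as much as the forward one; `grad_y` is an arbitrary stand-in for the gradient arriving from later layers.

```python
import torch

B, D, K = 8, 5, 3
x = torch.randn(B, D, requires_grad=True)
w = torch.randn(D, K, requires_grad=True)

y = x @ w                   # forward: 2*B*D*K FLOPs
grad_y = torch.randn(B, K)  # stand-in for dL/dy coming from later layers
y.backward(grad_y)

grad_w = x.detach().T @ grad_y  # dL/dw: (D x B) @ (B x K), 2*B*D*K FLOPs
grad_x = grad_y @ w.detach().T  # dL/dx: (B x K) @ (K x D), 2*B*D*K FLOPs
assert torch.allclose(w.grad, grad_w, atol=1e-5)
assert torch.allclose(x.grad, grad_x, atol=1e-5)

forward_flops = 2 * B * D * K
backward_flops = 4 * B * D * K
print(forward_flops + backward_flops == 6 * B * (D * K))  # True
```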
3. Models
3.1 Module Parameters
Model parameters are stored in PyTorch as nn.Parameter objects, which automatically track gradients (requires_grad=True).
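A sketch of what wrapping a tensor in nn.Parameter buys. `Scale` is a hypothetical toy module: the Parameter is registered with the module and receives gradients, while the plain tensor attribute is not.

```python
import torch
from torch import nn


class Scale(nn.Module):  # hypothetical toy module
    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))  # registered, requires_grad=True
        self.offset = torch.zeros(dim)               # plain tensor: not a parameter

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.weight + self.offset


m = Scale(4)
print([name for name, _ in m.named_parameters()])  # ['weight']
m(torch.randn(2, 4)).sum().backward()
print(m.weight.grad.shape)  # torch.Size([4]): filled in by autograd
```

If a non-trainable tensor should still move with the module (e.g. on .to(device)), nn.Module.register_buffer is the usual tool.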
Parameter Initialization
TODO: two initialization schemes
- Xavier initialization (paper)
- Kaiming He initialization
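A sketch of why initialization scale matters, followed by PyTorch's built-in Xavier and Kaiming initializers (`nn.init.xavier_normal_`, `nn.init.kaiming_normal_`). The dimensions are arbitrary; weights are laid out as (out, in), like nn.Linear, because PyTorch reads fan_in from the second dimension.

```python
import math
import torch
from torch import nn

input_dim, output_dim = 4096, 512

# Naive: standard normal weights make each output's std grow like sqrt(input_dim).
x = torch.randn(input_dim)
w_naive = torch.randn(output_dim, input_dim)
print((w_naive @ x).std())  # roughly sqrt(4096) = 64

# Scaling by 1/sqrt(input_dim) keeps outputs at roughly unit scale.
w_scaled = torch.randn(output_dim, input_dim) / math.sqrt(input_dim)
print((w_scaled @ x).std())  # roughly 1

# Built-in versions.
w_xavier = torch.empty(output_dim, input_dim)
nn.init.xavier_normal_(w_xavier)  # std = sqrt(2 / (fan_in + fan_out))
w_kaiming = torch.empty(output_dim, input_dim)
nn.init.kaiming_normal_(w_kaiming, nonlinearity="relu")  # std = sqrt(2 / fan_in)
print(w_xavier.std(), math.sqrt(2 / (input_dim + output_dim)))
print(w_kaiming.std(), math.sqrt(2 / input_dim))
```

Xavier balances fan-in and fan-out; Kaiming adds a factor of 2 to compensate for ReLU zeroing half of its inputs.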
3.2 Custom Model
Remember to move the model to the GPU as well: model.to(device).
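A minimal custom model, assuming a simple MLP (the class name `MLP` and its sizes are hypothetical). The point is that one `.to(device)` call moves every registered parameter, and the inputs must live on the same device.

```python
import torch
from torch import nn


class MLP(nn.Module):  # hypothetical example model
    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])
        self.head = nn.Linear(dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = torch.relu(layer(x))
        return self.head(x).squeeze(-1)  # (batch, dim) -> (batch,)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = MLP(dim=64, num_layers=2).to(device)  # moves every parameter to device
x = torch.randn(8, 64, device=device)          # inputs must be on the same device
print(model(x).shape)  # torch.Size([8])

num_params = sum(p.numel() for p in model.parameters())
print(num_params)  # 2 * (64*64 + 64) + (64 + 1) = 8385
```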
3.3 Randomness
Randomness shows up in many places: parameter initialization, dropout, data ordering, etc.
For reproducibility, we recommend you always pass in a different random seed for each use of randomness.
import random

import numpy as np
import torch

seed = 0
# Torch
torch.manual_seed(seed)
# NumPy
np.random.seed(seed)
# Python
random.seed(seed)
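One way to give each use of randomness its own seed is a dedicated `torch.Generator` per use. This is a sketch; `sample_noise`, `init_gen`, and `data_gen` are hypothetical names, and the seeds are arbitrary.

```python
import torch


def sample_noise(seed: int) -> torch.Tensor:
    # A dedicated Generator gives this use of randomness its own seed.
    gen = torch.Generator().manual_seed(seed)
    return torch.randn(3, generator=gen)


assert torch.equal(sample_noise(0), sample_noise(0))      # same seed, same values
assert not torch.equal(sample_noise(0), sample_noise(1))  # different seed, different values

# Hypothetical split: separate seeds for initialization and data ordering, so
# changing one does not perturb the other.
init_gen = torch.Generator().manual_seed(1)
data_gen = torch.Generator().manual_seed(2)
weights = torch.randn(4, 4, generator=init_gen)
order = torch.randperm(10, generator=data_gen)
```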
3.4 Data Loading
We don’t want to load the entire dataset into memory at once (the LLaMA training data is 2.8 TB).
Use memmap to lazily load only the accessed parts into memory.
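A sketch of memmap-based loading with `np.memmap`. A small synthetic token file stands in for a real corpus, and `get_batch` is a hypothetical helper that samples random windows plus next-token targets; only the sliced windows are read from disk.

```python
import os
import tempfile

import numpy as np
import torch

# Stand-in for a large tokenized corpus: 1000 token ids written to disk.
path = os.path.join(tempfile.mkdtemp(), "tokens.bin")
np.arange(1000, dtype=np.uint16).tofile(path)

# np.memmap maps the file; pages are read from disk only when accessed.
data = np.memmap(path, dtype=np.uint16, mode="r")


def get_batch(data: np.ndarray, batch_size: int, seq_len: int, device: str = "cpu"):
    """Sample random windows: x is the input, y the same window shifted by one."""
    starts = torch.randint(len(data) - seq_len, (batch_size,)).tolist()
    x = torch.stack([torch.from_numpy(data[s:s + seq_len].astype(np.int64)) for s in starts])
    y = torch.stack([torch.from_numpy(data[s + 1:s + seq_len + 1].astype(np.int64)) for s in starts])
    return x.to(device), y.to(device)


x, y = get_batch(data, batch_size=4, seq_len=16)
print(x.shape, y.shape)  # torch.Size([4, 16]) torch.Size([4, 16])
```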
3.5 Optimizer
TODO:
- momentum = SGD + exponential averaging of grad
- AdaGrad = SGD + averaging by grad^2
- RMSProp = AdaGrad + exponentially averaging of grad^2
- Adam = RMSProp + momentum
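AdaGrad from the list above, written as a `torch.optim.Optimizer` subclass. This is a sketch (the state key "g2" is a hypothetical name); its update should agree with `torch.optim.Adagrad` under that optimizer's defaults (no lr decay, zero initial accumulator).

```python
import torch


class AdaGrad(torch.optim.Optimizer):
    """SGD where each coordinate's step is divided by sqrt(sum of squared grads)."""

    def __init__(self, params, lr: float = 0.01, eps: float = 1e-10):
        super().__init__(params, dict(lr=lr, eps=eps))

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if "g2" not in state:
                    state["g2"] = torch.zeros_like(p)  # running sum of grad^2
                g2 = state["g2"]
                g2 += p.grad ** 2
                p -= group["lr"] * p.grad / (g2.sqrt() + group["eps"])
        return loss


# Usage: minimize sum(w^2) for a few steps.
w = torch.nn.Parameter(torch.tensor([1.0, -2.0]))
opt = AdaGrad([w], lr=0.1)
for _ in range(3):
    opt.zero_grad()
    loss = (w ** 2).sum()
    loss.backward()
    opt.step()
print(w)
```

The per-coordinate denominator only grows, so AdaGrad's effective step size shrinks over time; RMSProp's exponential averaging is the fix for that.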
3.6 Checkpointing
Training language models takes a long time, and the job will almost certainly crash at some point.
You don’t want to lose all your progress. During training, it is useful to periodically save your model and optimizer state to disk.
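A save/restore round trip, as a minimal sketch with a toy model. The checkpoint dictionary keys ("model", "optimizer", "step") are a common convention rather than anything PyTorch requires.

```python
import os
import tempfile

import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# One training step so the optimizer has state (AdamW's moment estimates) worth saving.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

# Save everything needed to resume: model weights, optimizer state, step counter.
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": 1}, path)

# Later (e.g. after a crash): rebuild the objects, then load their states.
checkpoint = torch.load(path)
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.AdamW(new_model.parameters(), lr=1e-3)
new_model.load_state_dict(checkpoint["model"])
new_optimizer.load_state_dict(checkpoint["optimizer"])
start_step = checkpoint["step"]
```

Saving the optimizer state matters: without it, AdamW would restart its moment estimates from zero on resume.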
3.7 Mixed Precision Training
The choice of data type (float32, bfloat16, fp8) involves tradeoffs:
- Higher precision: more accurate/stable, more memory, more compute
- Lower precision: less accurate/stable, less memory, less compute
Solution: use float32 by default, but use {bfloat16, fp8} when possible.
A concrete plan:
- Use {bfloat16, fp8} for the forward pass (activations).
- Use float32 for the rest (parameters, gradients).
Reference: Mixed Precision Training (Micikevicius et al., 2017).
PyTorch has an automatic mixed precision (AMP) library: https://pytorch.org/docs/stable/amp.html
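A minimal sketch of the plan above with AMP's `torch.autocast`: parameters and gradients stay float32 while the forward matmul runs in bf16. It falls back to the CPU when the GPU lacks bf16 support; with fp16 instead of bf16 you would normally also need loss scaling (`torch.amp.GradScaler`), which bf16's wider range usually makes unnecessary.

```python
import torch
from torch import nn

# bf16 autocast on GPU needs hardware support; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
device = "cuda" if use_cuda else "cpu"

model = nn.Linear(64, 64).to(device)  # parameters stay float32
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
x = torch.randn(8, 64, device=device)

with torch.autocast(device_type=device, dtype=torch.bfloat16):
    y = model(x)  # the matmul runs in bf16, producing a bf16 activation
loss = y.float().pow(2).mean()  # compute the loss in float32 outside autocast
loss.backward()
optimizer.step()

print(y.dtype, model.weight.dtype, model.weight.grad.dtype)
# torch.bfloat16 torch.float32 torch.float32
```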