Lecture 2 - PyTorch, resource accounting

1. Memory Accounting

Almost everything (parameters, gradients, activations, optimizer states) is stored as floating point numbers.
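As a rough sketch of what this accounting looks like in practice, the memory taken by a tensor's elements is its element count times the bytes per element. The helper name tensor_bytes is hypothetical; numel() and element_size() are standard tensor methods.

```python
import torch


def tensor_bytes(t: torch.Tensor) -> int:
    """Bytes used by the tensor's elements: element count times bytes per element."""
    return t.numel() * t.element_size()


x32 = torch.zeros(4, 8)                        # float32 is the default dtype (4 bytes)
x16 = torch.zeros(4, 8, dtype=torch.bfloat16)  # bfloat16 uses 2 bytes per element

print(tensor_bytes(x32))  # 4 * 8 * 4 = 128
print(tensor_bytes(x16))  # 4 * 8 * 2 = 64
```

The same count applies to parameters, gradients, activations, and optimizer states, since they are all tensors.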

Different types of floating point numbers:

Intuition: When to use float32 and bf16?
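One way to build this intuition, sketched under the assumption that the relevant tradeoff is dynamic range versus precision: float16 underflows on very small values, bfloat16 keeps roughly the range of float32 but with a much coarser resolution. The specific value 1e-8 is only an illustrative choice.

```python
import torch

small = 1e-8  # illustrative value, below what float16 can represent

as_fp16 = torch.tensor(small, dtype=torch.float16)   # underflows to zero
as_bf16 = torch.tensor(small, dtype=torch.bfloat16)  # stays nonzero (wide exponent range)

# Machine epsilon: the gap between 1.0 and the next representable number.
eps_fp32 = torch.finfo(torch.float32).eps
eps_bf16 = torch.finfo(torch.bfloat16).eps

print(as_fp16.item() == 0.0, as_bf16.item() != 0.0)
print(eps_fp32, eps_bf16)
```

The coarser epsilon of bfloat16 is why values that accumulate many small updates are usually kept in float32.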


2. Compute Accounting

2.1 Tensor on GPU

By default, tensors are stored in CPU memory. Use .to("cuda"), or pass device="cuda" when creating the tensor, to place it in GPU memory.
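A minimal sketch of both ways to place a tensor on the GPU. The fallback to "cpu" is an addition here so that the snippet also runs on a machine without a GPU.

```python
import torch

# Fall back to the CPU when no GPU is visible, so the sketch runs anywhere.
device = "cuda" if torch.cuda.is_available() else "cpu"

x = torch.zeros(32, 32)                 # created in CPU memory by default
y = x.to(device)                        # moved to the chosen device
z = torch.zeros(32, 32, device=device)  # created directly on the device

print(x.device, y.device, z.device)
```

Creating the tensor directly on the device avoids the extra host-to-device transfer that .to() performs.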


2.2 Tensor Operations

Tensor Storage

A PyTorch tensor is really a pointer to allocated memory, plus metadata describing how to get to each element.

stride(i): the number of elements to skip in the underlying storage to move to the next element along dimension i.
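To make the stride definition concrete, here is a small sketch that locates an element in the flat storage by hand. The 3 x 4 shape and the indices r, c are arbitrary choices for illustration.

```python
import torch

x = torch.arange(12.0).reshape(3, 4)  # 3 rows, 4 columns, laid out row by row

# Moving down one row skips 4 elements; moving right one column skips 1.
print(x.stride(0), x.stride(1))

# The flat position of element [r, c] is the stride-weighted sum of its indices.
r, c = 1, 2
offset = r * x.stride(0) + c * x.stride(1)
print(offset, x.flatten()[offset].item(), x[r, c].item())
```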

Tensor Slicing

Many operations only provide a different view of the tensor. No copy is made underneath, so a modification made through one tensor also affects the others.

For example:
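A short sketch of views sharing storage, assuming PyTorch 2.x (for untyped_storage()). The helper same_storage is a hypothetical name introduced only for this illustration.

```python
import torch


def same_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    """True when both tensors point at the same underlying memory."""
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()


x = torch.tensor([[1., 2, 3], [4, 5, 6]])
row = x[0]             # view of the first row
col = x[:, 1]          # view of the second column
t = x.transpose(1, 0)  # view with swapped strides

print(same_storage(x, row), same_storage(x, col), same_storage(x, t))

# Writing through x is visible through the view, because nothing was copied.
x[0][0] = 100
print(row[0].item())
```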

Note that some views are non-contiguous, which means that further views may not be possible.

import torch

x = torch.tensor([[1., 2, 3], [4, 5, 6]])
y = x.transpose(1, 0)  # a view with swapped strides
assert not y.is_contiguous()
try:
    y.view(2, 3)  # fails: this shape cannot be expressed with y's strides
    assert False
except RuntimeError as e:
    assert "view size is not compatible with input tensor's size and stride" in str(e)

However, you can first force a tensor to be contiguous:

y = x.transpose(1, 0).contiguous().view(2, 3)

Note that this does make a copy underneath.

Tensor Elementwise

These operations apply some operation to each element of the tensor and return a (new) tensor of the same shape.

triu takes the upper triangular part of a matrix.
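A minimal sketch of both points: an elementwise operation returns a new tensor of the same shape, and triu zeroes everything below the diagonal. The 3 x 3 input is an arbitrary choice.

```python
import torch

x = torch.ones(3, 3)

squared = x.pow(2)  # elementwise, returns a new tensor of the same shape
upper = x.triu()    # keeps the diagonal and everything above it

print(squared.shape)
print(upper)
# tensor([[1., 1., 1.],
#         [0., 1., 1.],
#         [0., 0., 1.]])
```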

Tensor Matmul

Matrix multiplication is the bread and butter of deep learning (an idiom for the core, most basic thing that everything else depends on).


2.3 Tensor Einops

Einops is a library for manipulating tensors where dimensions are named.

Jaxtyping Basics

jaxtyping is a library providing type annotations and runtime type-checking for:

  1. shape and dtype of JAX arrays; (Now also supports PyTorch, NumPy, MLX, and TensorFlow!)
  2. PyTrees.

2.4 Tensor Operations FLOPs

A FLOP (floating-point operation) is a basic operation like addition or multiplication.

Multiplying a B x D matrix by a D x K matrix takes 2 * B * D * K FLOPs: every multiplication is followed by an addition.
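The count can be checked with a naive sketch that walks the same triple loop a matrix multiplication performs. Both function names are hypothetical, and the small sizes are chosen only so the loop finishes instantly.

```python
def matmul_flops(B: int, D: int, K: int) -> int:
    """FLOPs for a (B x D) by (D x K) matmul: one multiply and one add per (b, d, k)."""
    return 2 * B * D * K


def count_by_loops(B: int, D: int, K: int) -> int:
    """Count the same FLOPs by enumerating every multiply-add explicitly."""
    flops = 0
    for _b in range(B):
        for _k in range(K):
            for _d in range(D):
                flops += 2  # one multiply, one add into the accumulator
    return flops


print(matmul_flops(3, 4, 5), count_by_loops(3, 4, 5))  # 120 120
```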

Interpretation:

Model FLOPs utilization (MFU)

Definition:

$$\text{MFU} = \frac{\text{actual FLOP/s}}{\text{promised FLOP/s}}$$

Usually, MFU of >= 0.5 is quite good (and will be higher if matmuls dominate).
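As a sketch of the arithmetic, MFU is the achieved throughput divided by the throughput the hardware promises. All numbers below are hypothetical placeholders, not measurements or real hardware specs.

```python
def mfu(num_flops: float, elapsed_seconds: float, promised_flops_per_sec: float) -> float:
    """Model FLOPs utilization: actual FLOP/s divided by promised FLOP/s."""
    actual_flops_per_sec = num_flops / elapsed_seconds
    return actual_flops_per_sec / promised_flops_per_sec


# Hypothetical numbers, chosen only to make the arithmetic easy to follow.
num_flops = 6e14         # FLOPs the model needed for one step
elapsed_seconds = 1.0    # measured wall-clock time for that step
promised = 1e15          # peak FLOP/s from a (made-up) spec sheet

utilization = mfu(num_flops, elapsed_seconds, promised)
print(round(utilization, 3))  # 0.6
```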


2.5 Gradients

Consider a weight w that connects an input unit i to an output unit j. For each example in the batch, the weight w generates exactly 6 FLOPs combined in the forward and backward pass:

  1. The unit i multiplies its output h(i) by w to send it to the unit j.
  2. The unit j adds the unit i’s contribution to its total input a(j).
  3. The unit j multiplies the incoming loss gradient dL/da(j) by w to send it back to the unit i.
  4. The unit i adds the unit j’s contribution to its total loss gradient dL/dh(i).
  5. The unit j multiplies its loss gradient dL/da(j) by the unit i’s output h(i) to compute the loss gradient dL/dw for the given example.
  6. (The sneakiest FLOP, IMHO) The weight w adds the contribution from step 5 to its loss gradient accumulator dL/dw that aggregates gradients for all examples.

Putting it together:
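The six steps above can be turned into a rough estimate: steps 1 and 2 are the forward pass (2 FLOPs per weight per example), and steps 3 to 6 are the backward pass (4 FLOPs). The function name and the example sizes are hypothetical.

```python
def training_flops(num_examples: int, num_params: int) -> int:
    """Approximate FLOPs for one forward and backward pass over the given examples."""
    forward = 2 * num_examples * num_params   # steps 1-2: multiply, add
    backward = 4 * num_examples * num_params  # steps 3-6: two multiplies, two adds
    return forward + backward


print(training_flops(num_examples=10, num_params=100))  # 6 * 10 * 100 = 6000
```

This counts only the weights that take part in matrix multiplications, which is the per-weight argument the list makes.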


3. Models

3.1 Module Parameters

Model parameters are stored in PyTorch as nn.Parameter objects. Gradients are computed for them automatically.
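A minimal sketch of a parameter receiving its gradient. The shapes and the sum() loss are arbitrary choices that make the gradient easy to verify by hand.

```python
import torch
from torch import nn

w = nn.Parameter(torch.randn(4, 2))  # requires_grad is True by default
x = torch.ones(3, 4)                 # a batch of 3 inputs

loss = (x @ w).sum()
loss.backward()                      # fills in w.grad

# Each entry of w is multiplied by a 1 from each of the 3 inputs, so its gradient is 3.
print(w.requires_grad, w.grad)
```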

Parameter Initialization

TODO: two kinds of initialization


3.2 Custom Model

Remember to move the model to the GPU with model.to(device).
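A sketch of what a custom model might look like. The class TinyLinearStack and its sizes are hypothetical, not the model from the lecture, and the device falls back to the CPU so the snippet runs without a GPU.

```python
import torch
from torch import nn


class TinyLinearStack(nn.Module):
    """Hypothetical model: a few square linear layers followed by a scalar head."""

    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(dim, dim, bias=False) for _ in range(num_layers)]
        )
        self.head = nn.Linear(dim, 1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return self.head(x).squeeze(-1)  # one scalar per input row


device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyLinearStack(dim=8, num_layers=2).to(device)  # moves every parameter

x = torch.randn(4, 8, device=device)  # inputs must live on the same device
y = model(x)
num_params = sum(p.numel() for p in model.parameters())
print(y.shape, num_params)  # 2 * 8 * 8 + 8 = 136 parameters
```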


3.3 Randomness

Randomness shows up in many places: parameter initialization, dropout, data ordering, etc.

For reproducibility, we recommend you always pass in a different random seed for each use of randomness.

import random

import numpy as np
import torch

seed = 0

# Torch
torch.manual_seed(seed)

# NumPy
np.random.seed(seed)

# Python
random.seed(seed)

3.4 Data Loading

You don’t want to load the entire dataset into memory at once (LLaMA data is 2.8TB).

Use memmap to lazily load only the accessed parts into memory.
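A small sketch with numpy.memmap, using a temporary file of integers as a stand-in for a tokenized dataset. The file name tokens.bin and the int32 dtype are assumptions made for this example.

```python
import os
import tempfile

import numpy as np

# Write a small stand-in dataset to disk (a real one would be far larger).
path = os.path.join(tempfile.mkdtemp(), "tokens.bin")
np.arange(1000, dtype=np.int32).tofile(path)

# Map the file into memory; nothing is read until a slice is accessed.
data = np.memmap(path, dtype=np.int32, mode="r")

# Copy only the requested window into an ordinary in-memory array.
window = np.array(data[100:104])
print(len(data), window)  # 1000 [100 101 102 103]
```

The dtype has to match the one used when the file was written, because the raw file stores no metadata.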


3.5 Optimizer

TODO:


3.6 Checkpointing

Training language models takes a long time and will certainly crash at some point.

You don’t want to lose all your progress. During training, it is useful to periodically save your model and optimizer state to disk.
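A minimal sketch of saving and restoring both states with torch.save and torch.load. The dictionary keys, the step counter, and the file name checkpoint.pt are illustrative choices rather than a required layout.

```python
import os
import tempfile

import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Bundle everything needed to resume into one dictionary.
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "step": 100,  # hypothetical training step
}
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save(checkpoint, path)

# Later (or after a crash): rebuild the objects, then load the saved state.
restored_model = nn.Linear(4, 2)
restored_optimizer = torch.optim.AdamW(restored_model.parameters(), lr=1e-3)
loaded = torch.load(path)
restored_model.load_state_dict(loaded["model"])
restored_optimizer.load_state_dict(loaded["optimizer"])
print(loaded["step"])  # 100
```

Saving the optimizer state matters because optimizers such as AdamW keep per-parameter statistics that would otherwise restart from zero.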


3.7 Mixed Precision Training

Choice of data type (float32, bfloat16, fp8) has tradeoffs:

Solution: use float32 by default, but use {bfloat16, fp8} when possible.

A concrete plan:

Mixed precision training: Micikevicius+ 2017

PyTorch has an automatic mixed precision (AMP) library: https://pytorch.org/docs/stable/amp.html
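A rough sketch of AMP's autocast context, run on the CPU with bfloat16 so it works without a GPU. On a GPU you would pass device_type="cuda" instead. The layer sizes are arbitrary.

```python
import torch
from torch import nn

model = nn.Linear(8, 4)  # parameters are created in float32
x = torch.randn(2, 8)

# Inside the context, eligible ops such as matmuls run in the lower-precision dtype.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = model(x)

print(model.weight.dtype)  # torch.float32: the stored parameters are untouched
print(y.dtype)             # dtype autocast chose for the linear layer's output
```

This matches the plan of keeping float32 as the default storage type while using a lower-precision type where possible.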

