In this brief post, I discuss some recent trends in ML and list some notable recent works.
The way we train SotA models is slightly different from a few years ago, with the goal of squeezing out as much performance as possible:
- We would first build a massive (often multimodal) dataset crawled from the Web and parallelize the model with techniques from DeepSpeed, GSPMD, etc.
- We would then scale the following variables based on the compute budget, either according to existing scaling laws or based on our own exploration at a smaller scale (see the sketch after this list):
  - Width of each layer
  - Depth of model
  - Batch size
  - Number of iterations
  - Learning rate
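As a concrete (and heavily simplified) sketch of this second step, the snippet below derives those variables from a compute budget expressed relative to a reference run. The base values and power-law exponents are placeholders for illustration only, not numbers from any particular scaling-law paper; in practice they would come from the scaling-law literature or from your own small-scale sweeps.

```python
# Hypothetical sketch: the base values and exponents below are placeholders,
# not numbers taken from any particular scaling-law paper.
def scale_hparams(relative_compute: float) -> dict:
    """Scale width/depth/batch size/iterations/LR from a relative compute budget."""
    base = {"width": 1024, "depth": 24, "batch_size": 512,
            "iterations": 100_000, "lr": 3e-4}
    r = relative_compute  # compute budget relative to the reference run
    return {
        # Grow model capacity faster than the number of update steps
        # (the exponents are made up for illustration only).
        "width": int(base["width"] * r ** 0.25),
        "depth": int(base["depth"] * r ** 0.15),
        "batch_size": int(base["batch_size"] * r ** 0.20),
        "iterations": int(base["iterations"] * r ** 0.20),
        # Larger models often use a somewhat smaller peak learning rate.
        "lr": base["lr"] * r ** -0.05,
    }

print(scale_hparams(8.0))  # e.g. 8x the compute of the reference run
```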
Suitable model design obviously depends on the problem setting. The modern choices are typically as follows:
- Decoder
  - Text output
    - Transformer decoder (a minimal sketch follows this list)
  - Non-text output (e.g. images)
    - Diffusion models (optionally with classifier guidance)
- Encoder
  - Text input
    - Transformer encoder
  - Non-text input
    - ViT variants
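To make the "Transformer decoder" entry above concrete, here is a minimal single-block causal decoder in Flax. The layer layout and hyperparameters are illustrative only and do not reproduce any particular model:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class DecoderBlock(nn.Module):
    """One pre-LayerNorm Transformer decoder block with causal self-attention."""
    d_model: int = 512
    num_heads: int = 8

    @nn.compact
    def __call__(self, x):  # x: [batch, seq_len, d_model]
        # Causal mask so each position only attends to earlier positions.
        causal_mask = nn.make_causal_mask(jnp.ones(x.shape[:2]))
        # Causal self-attention sub-layer with residual connection.
        h = nn.LayerNorm()(x)
        h = nn.SelfAttention(num_heads=self.num_heads)(h, mask=causal_mask)
        x = x + h
        # Position-wise feed-forward sub-layer with residual connection.
        h = nn.LayerNorm()(x)
        h = nn.Dense(4 * self.d_model)(h)
        h = nn.gelu(h)
        h = nn.Dense(self.d_model)(h)
        return x + h
```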
Notably, diffusion models are beginning to dominate the leaderboards, ahead of GANs, VAEs, and MLE-based models, in each non-text modality, and contrastive learning is dominating representation learning. Multimodal models (e.g. CLIP, DALL-E) are also dominating in various domains.
If you have an ample supply of GPUs and you are either writing your code from scratch or building on an existing PyTorch codebase, you may want to use PyTorch. Otherwise, you may want to use TPUs (and therefore JAX) by requesting larger pods from TRC. Thankfully, compared with GPUs, it is much easier to get a large amount of compute on TPUs by applying for TRC (applications are usually accepted, though there is no guarantee this will continue). For many people, this alone is a good reason to learn JAX. On top of JAX, it is generally recommended to use Flax or Haiku for model building.
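As a minimal sketch of what model building with Flax looks like in practice, here is how the DecoderBlock from the earlier sketch would be initialized and applied; parameters live in an explicit pytree, and the forward pass is a pure function that can be jit-compiled (shapes are arbitrary):

```python
# Minimal Flax usage of the DecoderBlock sketched above (shapes are arbitrary).
model = DecoderBlock(d_model=512, num_heads=8)
dummy = jnp.zeros((2, 16, 512))                    # [batch, seq_len, d_model]
params = model.init(jax.random.PRNGKey(0), dummy)  # parameters as a pure pytree
out = jax.jit(model.apply)(params, dummy)          # jit-compiled forward pass
print(out.shape)                                   # (2, 16, 512)
```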
Some notable recent works
Scaling laws:
Scaling tricks:
- DeepSpeed (scaling tricks for PyTorch)
- GSPMD (scaling tricks for JAX, e.g., pjit)
- WebDataset (recommended data pipeline for large-scale training on PyTorch; see the sketch after this list)
- MoE / Switch Transformer
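As a small illustration of the WebDataset entry above, the snippet below streams sharded image-text data into a PyTorch DataLoader. The shard URL pattern and the "jpg"/"txt" keys are hypothetical and must match how the tar shards were actually written:

```python
# Hedged sketch: streaming a sharded image-text dataset with WebDataset.
import webdataset as wds
import torch
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

shards = "https://example.com/shards/data-{000000..000999}.tar"  # hypothetical shard pattern
dataset = (
    wds.WebDataset(shards)
    .shuffle(1000)                       # buffer-based sample shuffling
    .decode("pil")                       # decode image bytes to PIL images
    .to_tuple("jpg", "txt")              # (image, caption) pairs; keys depend on the shards
    .map_tuple(preprocess, lambda t: t)  # resize/crop images, pass captions through
)
loader = torch.utils.data.DataLoader(dataset, batch_size=64, num_workers=8)
# Iterate `loader` in the training loop as usual.
```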
Massive datasets:
- The Pile (massive text dataset)
- LAION-400M (massive image-text dataset)
LMs:
- GPT-3
- FLAN
- T5 / BART
- GLM
- Measuring Massive Multitask Language Understanding
- Hurdles to Progress in Long-form Question Answering
Diffusion models:
- Denoising Diffusion Probabilistic Models
- Improved Denoising Diffusion Probabilistic Models
- Diffusion Models Beat GANs on Image Synthesis
Multimodal models:
VAE & GAN:
RL:
- DreamerV2
- Decision Transformer / Trajectory Transformer
- DrQ-v2
- Pretraining Representations for Data-Efficient Reinforcement Learning
JAX:
Misc: