Scaling RNNs to Billions of Parameters with Zero-Order Optimization
- #Zero-Order Optimization
- #Machine Learning
- #Recurrent Neural Networks
- Recurrent Neural Networks (RNNs) require constant FLOPs and GPU memory per token during inference, unlike transformers, whose per-token cost and memory grow linearly with context length (see the first sketch after this list).
- Training large RNNs on long contexts is impractical with Backpropagation Through Time (BPTT), which must store activations for every timestep and therefore needs memory that grows with sequence length.
- Zero-Order Optimization (ZOO) methods such as Random-vector Gradient Estimation (RGE) can replace BPTT, reducing memory usage and cost while matching or exceeding BPTT's convergence rate.
- Central-Difference RGE (CD-RGE) optimizes a smoothed surrogate of the loss, which acts as a regularizer and improves generalization (see the CD-RGE sketch after this list).
- The method matches or outperforms BPTT on overfitting, transduction, and language-modeling tasks, often in fewer training steps.
- Despite requiring more forward passes per step, the method can surpass BPTT in wall-clock time by leveraging advances such as FlashRNN and distributed inference.
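
A minimal sketch of the inference-scaling claim above, not taken from the paper: `torch.nn.GRUCell` stands in for the recurrent cell and the sizes are illustrative. The RNN carries a fixed-size hidden state across tokens, while a transformer-style KV cache grows with every token it has seen.

```python
import torch

d_model, seq_len = 512, 4096
cell = torch.nn.GRUCell(d_model, d_model)  # stand-in recurrent cell (illustrative)

h = torch.zeros(1, d_model)   # RNN state: same size at every step
kv_cache = []                 # transformer-style cache: grows with each token

for t in range(seq_len):
    x = torch.randn(1, d_model)   # dummy token embedding
    h = cell(x, h)                # O(1) state and O(1) FLOPs per token
    kv_cache.append(x)            # O(t) memory and O(t) attention FLOPs per token

print(h.shape)        # torch.Size([1, 512]) -- constant regardless of context
print(len(kv_cache))  # 4096 -- linear in context length
```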
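A minimal sketch of Central-Difference RGE as summarized above, assuming a generic PyTorch model and a user-supplied `loss_fn(model, batch)`; `epsilon`, `num_perturbations`, and the plain SGD update are illustrative choices, not the paper's exact recipe. Each probe direction needs only two forward passes, so no activation cache is stored for backpropagation.

```python
import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters

@torch.no_grad()
def cd_rge_step(model, loss_fn, batch, lr=1e-3, epsilon=1e-3, num_perturbations=16):
    params = list(model.parameters())
    theta = parameters_to_vector(params)        # flat copy of current weights
    grad_est = torch.zeros_like(theta)

    for _ in range(num_perturbations):
        v = torch.randn_like(theta)             # random probe direction

        vector_to_parameters(theta + epsilon * v, params)
        loss_plus = loss_fn(model, batch)       # forward pass only

        vector_to_parameters(theta - epsilon * v, params)
        loss_minus = loss_fn(model, batch)      # forward pass only

        # Central-difference directional derivative projected back onto v:
        # g ~= (L(theta + eps*v) - L(theta - eps*v)) / (2*eps) * v
        grad_est += (loss_plus - loss_minus) / (2 * epsilon) * v

    grad_est /= num_perturbations
    vector_to_parameters(theta - lr * grad_est, params)   # plain SGD update (illustrative)
    return grad_est
```

Averaging over more probe directions reduces the variance of the estimate at the cost of extra forward passes, which is the FLOPs-for-memory trade-off noted in the last bullet.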