A non-diagonal SSM RNN computed in parallel without requiring stabilization
6 months ago
- #deep-learning
- #RNN
- #PyTorch
- Implementation of a deep RNN with a non-diagonal linear state-space model (SSM) computed over generalized orders of magnitude (GOOMs).
- Recurrent states can fluctuate over a greater dynamic range of real values, enabling parallel computation of non-diagonal recurrences via prefix scan.
- Model is implemented as a standard PyTorch nn.Module; it is only partially compilable because PyTorch's support for complex tensors in torch.compile is incomplete.
- Training details include batch size, optimizer (AdamW), learning rate schedule, and other hyper-parameters for tasks like natural language generation.
- Model performance: Cross-entropy loss of ~2.7 after training on 10B tokens, comparable to state-of-the-art models.
- Additional tasks include Sequential MNIST generation/classification, Wikitext-103, and Copy-Memory tasks.
- Model provides convenience methods like get_param_groups(), compute_loss_and_metrics(), and generate().
- GOOMs are implemented as torch.complex64 tensors, with custom torch.autograd.Function subclasses for proper gradient backpropagation.
- Training autocasts floating-point tensors to torch.float16; complex GOOM tensors are excluded from autocasting.
- Work originated from exploring parallel computation of non-diagonal linear recurrences via complex plane mapping.
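The core idea of parallelizing a non-diagonal linear recurrence h_t = A_t h_{t-1} is that the cumulative matrix products are associative, so they can be computed with a prefix scan in O(log T) parallel steps instead of a sequential loop. Below is a minimal sketch of that scan using a Hillis–Steele doubling scheme; the function name is hypothetical, and it multiplies matrices directly in float32, without the GOOM log-space representation the project uses to keep long products from overflowing or underflowing:

```python
import torch


def matrix_prefix_scan(A: torch.Tensor) -> torch.Tensor:
    """Inclusive prefix scan of matrix products along dim 0.

    A: tensor of shape (T, n, n).
    Returns P of shape (T, n, n) with P[t] = A[t] @ A[t-1] @ ... @ A[0],
    so that h_t = P[t] @ h_0 solves the recurrence h_t = A_t h_{t-1}.
    Hillis-Steele doubling: O(log T) parallel steps; the per-step batched
    matmul respects the non-commutative left-multiplication order.
    """
    P = A.clone()
    T = A.shape[0]
    step = 1
    while step < T:
        # Each position t >= step absorbs the partial product `step` positions
        # back; newer factors multiply on the left.
        P = torch.cat([P[:step], P[step:] @ P[:-step]], dim=0)
        step *= 2
    return P
```

In the actual model, this kind of scan is applied to GOOM-encoded states, which is what lets the recurrent states span a much larger dynamic range without stabilization tricks.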
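To illustrate the GOOM encoding itself: a real number x maps to its complex logarithm log|x| + i·arg(x), which represents negative values (arg = pi) and huge magnitudes without overflow, and a custom torch.autograd.Function keeps the backward pass real-valued. This is a hypothetical sketch, not the repo's implementation; the class name and the exact gradient handling are assumptions:

```python
import torch


class RealToGoom(torch.autograd.Function):
    """Encode real values as GOOMs (complex logarithms).

    Hypothetical sketch: forward maps x -> log|x| + i*arg(x) as
    torch.complex64; backward returns a real gradient, since arg(x) is
    piecewise constant for real x and only log|x| varies.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        # torch.log on a complex tensor yields log|x| + i*arg(x);
        # negative reals get imaginary part pi.
        return torch.log(x.to(torch.complex64))

    @staticmethod
    def backward(ctx, grad_z):
        (x,) = ctx.saved_tensors
        # d(log|x|)/dx = 1/x; the imaginary channel contributes nothing
        # for real inputs, so keep only the real part of the gradient.
        return grad_z.real / x


def goom_to_real(z: torch.Tensor) -> torch.Tensor:
    """Decode a GOOM back to a real value: exp(z) and drop the imaginary part."""
    return torch.exp(z).real
```

Round-tripping through the encoding recovers the original values (including negatives), and gradients flow through as if the encode/decode pair were the identity.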