A non-diagonal SSM RNN computed in parallel without requiring stabilization
4 months ago
- #deep-learning
- #RNN
- #PyTorch
- 实现了一个基于广义数量级(GOOMs)的非对角线性状态空间模型(SSM)的深度RNN
- 循环状态可在更大动态范围的实数值上波动,通过前缀扫描实现非对角循环的并行计算
- 模型以标准PyTorch nn.Module实现,由于PyTorch对复数张量的不完全支持而部分可编译
- 训练细节包括批量大小、优化器(AdamW)、学习率调度表,以及自然语言生成等任务的其他超参数
- 模型性能:在100亿token训练后交叉熵损失约2.7,与最先进模型相当
- 额外任务包括Sequential MNIST生成/分类、Wikitext-103和Copy-Memory任务
- 模型提供便捷方法如get_param_groups()、compute_loss_and_metrics()和generate()
- GOOMs以torch.complex64张量实现,使用自定义torch.Autograd.function确保梯度正确反向传播
- 训练中对浮点张量使用torch.float16自动转换,但不适用于复数GOOMs
- 该工作源于探索通过复平面映射实现非对角线性循环的并行计算