Hasty Briefsbeta

双语

A non-diagonal SSM RNN computed in parallel without requiring stabilization

4 months ago
  • #deep-learning
  • #RNN
  • #PyTorch
  • 实现了一个基于广义数量级(GOOMs)的非对角线性状态空间模型(SSM)的深度RNN
  • 循环状态可在更大动态范围的实数值上波动,通过前缀扫描实现非对角循环的并行计算
  • 模型以标准PyTorch nn.Module实现,由于PyTorch对复数张量的不完全支持而部分可编译
  • 训练细节包括批量大小、优化器(AdamW)、学习率调度表,以及自然语言生成等任务的其他超参数
  • 模型性能:在100亿token训练后交叉熵损失约2.7,与最先进模型相当
  • 额外任务包括Sequential MNIST生成/分类、Wikitext-103和Copy-Memory任务
  • 模型提供便捷方法如get_param_groups()、compute_loss_and_metrics()和generate()
  • GOOMs以torch.complex64张量实现,使用自定义torch.Autograd.function确保梯度正确反向传播
  • 训练中对浮点张量使用torch.float16自动转换,但不适用于复数GOOMs
  • 该工作源于探索通过复平面映射实现非对角线性循环的并行计算