Computing Sharding with Einsum
- #einsum
- #sharding
- #tensor_operations
- Einsum notation expresses tensor operations compactly by naming every index of the inputs and the output; an index that appears in the inputs but not the output is summed (contracted) over, so the shapes of everything involved are visible in the expression itself (sketch 1 below).
- Einsum makes gradients cheap to derive: the backward pass of an einsum is itself an einsum, obtained by swapping the differentiated input's index string with the output's (sketch 2 below).
- Sharding rules for einsum break down into per-dimension cases: replicated, sharded batch, sharded free, and sharded contraction dimensions. Sharding a batch or free dimension simply shards the output along it; only contracting over a sharded dimension produces Partial() results that require communication (sketch 3 below).
- Tensor parallelism example shows how sharding a weight dimension propagates into the backward pass: gradients computed by contracting over the sharded dimension come out as Partial() and require an all-reduce (sketch 4 below).
- Sequence parallel example demonstrates sharding on the sequence dimension: gradients that contract over it (notably weight gradients) come out as Partial() and need an all-reduce (sketch 5 below).
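
Sketch 1, for the notation point: a minimal JAX example (shapes are arbitrary) showing that the subscript string spells out every input's and the output's axes, and that an index appearing only on the inputs is contracted.

```python
import jax.numpy as jnp

A = jnp.ones((2, 3))                 # indices: i j
B = jnp.ones((3, 4))                 # indices: j k
C = jnp.einsum('ij,jk->ik', A, B)    # j appears only on inputs -> contracted
print(C.shape)                       # (2, 4)

# A batched matmul just adds a shared batch index b:
X = jnp.ones((8, 2, 3))              # b i j
Y = jnp.ones((8, 3, 4))              # b j k
Z = jnp.einsum('bij,bjk->bik', X, Y)
print(Z.shape)                       # (8, 2, 4)
```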
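Sketch 2, for the gradient point: the backward pass of an einsum is itself an einsum whose spec swaps the differentiated input's indices with the output's. The check against jax.vjp below is a minimal illustration, not taken from the original post.

```python
import jax
import jax.numpy as jnp

A = jnp.arange(6.0).reshape(2, 3)     # i j
B = jnp.arange(12.0).reshape(3, 4)    # j k

f = lambda A, B: jnp.einsum('ij,jk->ik', A, B)
C, f_vjp = jax.vjp(f, A, B)
dC = jnp.ones_like(C)                 # upstream gradient, indices i k
dA, dB = f_vjp(dC)

# Each backward pass is an einsum with the input's and output's roles swapped:
dA_ref = jnp.einsum('ik,jk->ij', dC, B)   # 'ij' moved to the output side
dB_ref = jnp.einsum('ij,ik->jk', A, dC)   # 'jk' moved to the output side
assert jnp.allclose(dA, dA_ref) and jnp.allclose(dB, dB_ref)
```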
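Sketch 3, for the contraction case of the sharding rules: two devices are simulated here with plain array slices rather than a real device mesh. Each local einsum sees only a slice of the contracted index j, so it yields a partial sum (the note's Partial() placement), and adding the partials stands in for the all-reduce. Sharding a batch or free dimension instead would just slice the output, with no communication.

```python
import jax.numpy as jnp

A = jnp.arange(8.0).reshape(2, 4)     # i j
B = jnp.arange(12.0).reshape(4, 3)    # j k

# Shard the contraction dim j across two simulated devices:
A_shards = (A[:, :2], A[:, 2:])
B_shards = (B[:2, :], B[2:, :])

# Each local einsum sees only part of j, so it yields a partial sum:
partials = [jnp.einsum('ij,jk->ik', a, b) for a, b in zip(A_shards, B_shards)]

# Summing the partials is the all-reduce that turns Partial() into a full result:
full = partials[0] + partials[1]
assert jnp.allclose(full, jnp.einsum('ij,jk->ik', A, B))
```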
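Sketch 4, for the tensor-parallelism point, again simulating two devices with slices: with the weight sharded on its free (output) dimension k, the forward einsum needs no communication, but the activation gradient contracts over k, so each device produces only a Partial() piece of dx and an all-reduce (here a plain sum) completes it. The column-parallel layout is an assumption about the example, not confirmed by the note.

```python
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)     # i j  (activations, replicated)
W = jnp.arange(12.0).reshape(3, 4)    # j k  (weight, sharded on free dim k)
W_shards = (W[:, :2], W[:, 2:])

# Forward: each device computes its own slice of y, no communication.
y_shards = [jnp.einsum('ij,jk->ik', x, w) for w in W_shards]

# Backward for x: dx = einsum('ik,jk->ij', dy, W) contracts over sharded k,
# so each device only produces a partial dx.
dy_shards = [jnp.ones_like(y) for y in y_shards]
dx_partials = [jnp.einsum('ik,jk->ij', dy, w)
               for dy, w in zip(dy_shards, W_shards)]
dx = dx_partials[0] + dx_partials[1]   # the all-reduce

dy_full = jnp.concatenate(dy_shards, axis=1)
assert jnp.allclose(dx, jnp.einsum('ik,jk->ij', dy_full, W))
```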
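Sketch 5, for the sequence-parallel point: activations sharded on the sequence dimension s leave the forward matmul communication-free, since s is a batch-like dimension of the output, but the weight gradient contracts over s, so each device's dW comes out Partial() and needs an all-reduce. Shapes and the two-way split are illustrative.

```python
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(4, 3)    # s i  (activations, sharded on s)
W = jnp.arange(6.0).reshape(3, 2)     # i k  (weight, replicated)
x_shards = (x[:2], x[2:])

# Forward: s is a batch-like dim of the output, so each device just
# computes its own slice of y with no communication.
y_shards = [jnp.einsum('si,ik->sk', xs, W) for xs in x_shards]

# Backward for W: dW = einsum('si,sk->ik', x, dy) contracts over the
# sharded dim s, so each device holds only a partial dW.
dy_shards = [jnp.ones_like(y) for y in y_shards]
dW_partials = [jnp.einsum('si,sk->ik', xs, dys)
               for xs, dys in zip(x_shards, dy_shards)]
dW = dW_partials[0] + dW_partials[1]   # the all-reduce

dy_full = jnp.concatenate(dy_shards, axis=0)
assert jnp.allclose(dW, jnp.einsum('si,sk->ik', x, dy_full))
```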