20x less peak RAM in the new PyTorch memory budget solver
15 hours ago
- #memory-optimization
- #PyTorch
- #knapsack-problem
- PyTorch uses a knapsack solver for memory planning during computation graph building.
- Default solver is `dp_knapsack`, a dynamic programming approach with high memory usage.
- New solver `dp_knapsack_sliding_hirschberg` cuts peak RAM usage by 20x and runs ~37% faster.
- Hirschberg's algorithm removes the need to keep the full DP table for backtracking: it splits the problem recursively and rebuilds the solution from the two halves.
- Sliding-window optimization shrinks the DP table from (items × max_weight) to (2 × max_weight) by keeping only the previous and current rows.
- To trade optimality for speed, use `greedy_knapsack`; for exact solutions using SciPy, use `ilp_knapsack`.
- `dp_knapsack_sliding_hirschberg` is currently only available in PyTorch's main branch (unreleased).
- Early adopters can enable it by setting `activation_memory_budget_solver` to `dp_knapsack_sliding_hirschberg`.
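
To make the two ideas in the summary concrete, here is a minimal pure-Python sketch of a 0/1 knapsack that combines a two-row sliding window with Hirschberg-style divide and conquer. It is an illustration only, not PyTorch's implementation: the function names `knapsack_last_row` and `hirschberg_knapsack` are hypothetical, and it assumes non-negative integer weights and non-negative values.

```python
from typing import List, Sequence


def knapsack_last_row(weights: Sequence[int], values: Sequence[float],
                      capacity: int) -> List[float]:
    """Return row[c] = best value within capacity c, for c in 0..capacity.

    Only two rows (previous and current) are alive at any time, so memory
    is O(capacity) instead of O(items * capacity).
    """
    prev = [0.0] * (capacity + 1)
    for w, v in zip(weights, values):
        curr = prev[:]  # default: skip this item
        for c in range(w, capacity + 1):
            curr[c] = max(prev[c], prev[c - w] + v)  # skip vs. take
        prev = curr
    return prev


def hirschberg_knapsack(weights: Sequence[int], values: Sequence[float],
                        capacity: int, offset: int = 0) -> List[int]:
    """Return indices of an optimal item set without storing the full table.

    Hypothetical sketch: split the items in half, find how much capacity
    the left half should get, then recurse on each half independently.
    """
    n = len(weights)
    if n == 0:
        return []
    if n == 1:
        take = weights[0] <= capacity and values[0] > 0
        return [offset] if take else []

    mid = n // 2
    left = knapsack_last_row(weights[:mid], values[:mid], capacity)
    right = knapsack_last_row(weights[mid:], values[mid:], capacity)

    # Capacity k for the left half that maximizes the combined value.
    k = max(range(capacity + 1), key=lambda c: left[c] + right[capacity - c])

    return (hirschberg_knapsack(weights[:mid], values[:mid], k, offset)
            + hirschberg_knapsack(weights[mid:], values[mid:],
                                  capacity - k, offset + mid))


weights = [1, 3, 4, 5]
values = [1.0, 4.0, 5.0, 7.0]
chosen = hirschberg_knapsack(weights, values, capacity=7)
print(chosen, sum(values[i] for i in chosen))
```

The recursion never backtracks through a stored table: each level keeps only a few rows of length `capacity + 1`, and the split point `k` tells each half how much capacity it may use.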