bugfix of pytorch out of memory
记录一次模型推理过程中显存不断增加直至Out-of-Memory的解决过程.
场景
当我在算法中加入了一个运算量不大的操作后, 模型推理过程直接从原来的显存占用不到11G直接爆了3090的显存.
显存增加定位
首先需要判断是什么地方造成了显存的增加. 这里推荐一个库Pytorch-Memory-Utils.该库可以通过侵入式的track方法记录两个track方法之间的显存变化. 其记录的格式如下
1 | At main.py line 18: <module> Total Tensor Used Memory:466.4 Mb Total Allocated Memory:466.4 Mb |
最终是发现在模型decoder部分导致了显存的激增, 并且随着推理的进行,不断累加.
解决
解决这里参考了Pytorch模型测试时显存一直上升导致爆显存, 原因应该是decoder中的梯度在一直计算并累积导致显存的不合理累加. 因此解决方案也比较直接:
1 | with torch.no_grad(): |
将decoder的过程用torch.no_grad()
wrap起来. bug就解决好了.