inference这么慢原来是因为这个!
model.train()
/ model.eval()
with model.no_grad()
最近在验证一个想法,写了个其实也不怎么复杂的网络,但是一到eval阶段就OOM(最近无卡可用的悲哀),然后找到了同病相怜的朋友👇Running out of memory during evaluation in Pytorch
pytorc中 model.train()用于在训练阶段,model.eval()用在验证和测试阶段,区别是对于Dropout和Batch Normlization层的影响。
在train模式下,dropout网络层会按照设定的参数p设置保留激活单元的概率(保留概率=p); batchnorm层会继续计算数据的mean和var等参数并更新。在val模式下,dropout层会让所有的激活单元都通过,而batchnorm层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。
with torch.no_grad是指停止自动求导。三者的概念比较简单,具体在网络的的用法👇
1 |
|
本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!