inference这么慢原来是因为这个!

model.train() / model.eval() with model.no_grad()

最近在验证一个想法,写了个其实也不怎么复杂的网络,但是一到eval阶段就OOM(最近无卡可用的悲哀),然后找到了同病相怜的朋友👇
Running out of memory during evaluation in Pytorch

pytorc中 model.train()用于在训练阶段,model.eval()用在验证和测试阶段,区别是对于Dropout和Batch Normlization层的影响。
在train模式下,dropout网络层会按照设定的参数p设置保留激活单元的概率(保留概率=p); batchnorm层会继续计算数据的mean和var等参数并更新。在val模式下,dropout层会让所有的激活单元都通过,而batchnorm层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。
with torch.no_grad是指停止自动求导。三者的概念比较简单,具体在网络的的用法👇

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 训练阶段
for epoch in range(max_epoch):
model.train()
for i, data in enumerate(dataloader):
# 假设包含有 doc,label数据
# 因为doc ,labels是输入数据,可以使用with torch.no_grad()停止对他们的求导(❗️就是这,加上了我就不oom了,因为输入数据实在是太多了TAT
# 当然不用也非常OK,不过在我这种输入数据很大的情况下可以明显提升速度和减少显存占用
with torch.no_grad():
inputs_ids = data[0]
labels = data[1]
# 测试阶段
model.eval()
with torch.no_grad():
....
# 在测试阶段使用with torch.no_grad()可以对整个网络都停止自动求导,可以大大大大加快速度,也可以使用大的batch_size来测试(当然不用也可以啦

本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!