在使用.to()指定训练设备时,建议使用non_blocking=True加速数据下沉来提升训练性能。开启non_blocking=True,可以使用数据下沉加速,减少数据在device侧和host侧转换的时间。
样例脚本优化前如下。
for step, batch in enumerate(epoch_iterator): batch = tuple(t.to(args.device) for t in batch) x, y = batch loss = model(x, y)
优化后,开启non_blocking=True。
for step, batch in enumerate(epoch_iterator): batch = tuple(t.to(args.device, non_blocking=True) for t in batch) x, y = batch loss = model(x, y)