此场景需增加源码并使能,请用户根据性能情况判断是否需要替换,若性能差异较大推荐且耗时问题出现在数据处理可使用此方案。
DataLoader在每个epoch开始的时候都会重新创建一次,因此每个epocch开始所有的worker会重新开始prefetching过程,就会引起数据读取过程的耗时。可以通过使用MultiEpochsDataLoader以减少重新创建epoch造成的数据读取耗时。
MultiEpochsDataLoader相关源码如下,使用该部分代码替换原有dataloader,并在代码中调用此方法。
class MultiEpochsDataLoader(torch.utils.data.DataLoader): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._DataLoader__initialized = False self.batch_sampler = _RepeatSampler(self.batch_sampler) self._DataLoader__initialized = True self.iterator = super().__iter__() def __len__(self): return len(self.batch_sampler.sampler) def __iter__(self): for i in range(len(self)): yield next(self.iterator) class _RepeatSampler(object): """ Sampler that repeats forever. Args: sampler (Sampler) """ def __init__(self, sampler): self.sampler = sampler def __iter__(self): while True: yield from iter(self.sampler)
MultiEpochsDataLoader使用方法请参见开源社区。