使用MultiEpochsDataLoader

此场景需增加源码并使能,请用户根据性能情况判断是否需要替换,若性能差异较大推荐且耗时问题出现在数据处理可使用此方案。

使用场景

DataLoader在每个epoch开始的时候都会重新创建一次,因此每个epocch开始所有的worker会重新开始prefetching过程,就会引起数据读取过程的耗时。可以通过使用MultiEpochsDataLoader以减少重新创建epoch造成的数据读取耗时。

操作步骤

MultiEpochsDataLoader相关源码如下,使用该部分代码替换原有dataloader,并在代码中调用此方法。

class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        self.iterator = super().__iter__()
    def __len__(self):
        return len(self.batch_sampler.sampler)
    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)
class _RepeatSampler(object):
    """ Sampler that repeats forever.
    Args:
        sampler (Sampler)
    """
    def __init__(self, sampler):
        self.sampler = sampler
    def __iter__(self):
        while True:
            yield from iter(self.sampler)

MultiEpochsDataLoader使用方法请参见开源社区