mmlearn.datasets.core.samplers.DistributedEvalSampler
class DistributedEvalSampler(dataset, num_replicas=None, rank=None, shuffle=False, seed=0)[source]
Sampler for distributed evaluation.
The main differences between this and torch.utils.data.DistributedSampler are that this sampler does not add extra samples to make the total evenly divisible across processes, and that shuffling is disabled by default (a brief sketch of this behavior follows the parameter list below).

Parameters:
dataset (torch.utils.data.Dataset) – Dataset used for sampling.
num_replicas (Optional[int], optional, default=None) – Number of processes participating in distributed evaluation. By default, world_size is retrieved from the current distributed group.
rank (Optional[int], optional, default=None) – Rank of the current process within num_replicas. By default, rank is retrieved from the current distributed group.
shuffle (bool, optional, default=False) – If True, the sampler will shuffle the indices. Defaults to False.
seed (int, optional, default=0) – Random seed used to shuffle the sampler if shuffle=True. This number should be identical across all processes in the distributed group.
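A minimal sketch (not part of the original documentation) of the no-padding behavior described above. It assumes that passing explicit num_replicas and rank, as the signature allows, removes the need for an initialized process group; the exact per-rank assignment may differ, but no index is repeated:

>>> import torch
>>> from torch.utils.data import TensorDataset
>>> from mmlearn.datasets.core.samplers import DistributedEvalSampler
>>> dataset = TensorDataset(torch.arange(10))  # 10 samples split across 3 simulated processes
>>> shards = [list(DistributedEvalSampler(dataset, num_replicas=3, rank=r)) for r in range(3)]
>>> sorted(i for shard in shards for i in shard)  # every index appears exactly once, no padding
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]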
Warning
DistributedEvalSampler should NOT be used for training; the distributed processes could hang forever. See [1] for details.
Notes
This sampler is intended for evaluation, where synchronization does not need to happen every epoch. Synchronization should instead be done outside the dataloader loop (see the sketch after the example below). It is especially useful in conjunction with torch.nn.parallel.DistributedDataParallel [2].

The input Dataset is assumed to be of constant size.

This implementation is adapted from [3].
References
Examples
>>> def example():
...     start_epoch, n_epochs = 0, 2
...     sampler = DistributedEvalSampler(dataset) if is_distributed else None
...     loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)
...     for epoch in range(start_epoch, n_epochs):
...         if is_distributed:
...             sampler.set_epoch(epoch)
...         evaluate(loader)
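A hedged follow-up sketch (not part of the original documentation) showing one way the evaluate call above could synchronize results outside the dataloader loop, by reducing per-rank counts with torch.distributed.all_reduce after the local shard has been consumed; the model, loader and device objects are assumed to already exist:

>>> import torch
>>> import torch.distributed as dist
>>> @torch.no_grad()
... def evaluate_and_sync(model, loader, device):
...     model.eval()
...     correct = torch.zeros(1, device=device)
...     total = torch.zeros(1, device=device)
...     for inputs, targets in loader:  # no collective calls inside the loop
...         preds = model(inputs.to(device)).argmax(dim=1)
...         correct += (preds == targets.to(device)).sum()
...         total += targets.numel()
...     dist.all_reduce(correct)  # default op is SUM; runs once, after the loop
...     dist.all_reduce(total)
...     return (correct / total).item()  # identical value on every rank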