The Sampler produces the indices of items; the Dataset fetches the sample for each given index.
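A minimal sketch of that division of labor (the real fetch logic lives in torch.utils.data._utils.fetch; this only illustrates the idea):
import torch
from torch.utils.data import TensorDataset, SequentialSampler
dataset = TensorDataset(torch.arange(10, 15))   # map-style dataset with 5 samples
sampler = SequentialSampler(dataset)            # yields indices 0, 1, 2, 3, 4
for index in sampler:        # Sampler produces the index
    sample = dataset[index]  # Dataset.__getitem__ maps the index to a sample
    print(index, sample)     # e.g. 0 (tensor(10),)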
Dataset
T_co = TypeVar("T_co", covariant=True)
T = TypeVar("T")
T_dict = Dict[str, T_co]
T_tuple = Tuple[T_co, ...]
T_stack = TypeVar("T_stack", T_tuple, T_dict)
class Dataset(Generic[T_co]):
r"""An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`. Subclasses could also
optionally implement :meth:`__getitems__` to speed up batched sample
loading. This method accepts a list of indices of the samples in a batch and
returns the list of samples.
.. note::
:class:`~torch.utils.data.DataLoader` by default constructs an index
sampler that yields integral indices. To make it work with a map-style
dataset with non-integral indices/keys, a custom sampler must be provided.
"""
def __getitem__(self, index) -> T_co: # get item by index
raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")
# def __getitems__(self, indices: List) -> List[T_co]:
# Not implemented to prevent false-positives in fetcher check in
# torch.utils.data._utils.fetch._MapDatasetFetcher
def __add__(self, other: "Dataset[T_co]") -> "ConcatDataset[T_co]":
return ConcatDataset([self, other])
# No `def __len__(self)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
# in pytorch/torch/utils/data/sampler.py
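As a reference, a minimal map-style dataset only needs __getitem__, plus __len__ for the default sampler; the class and data below are made up purely for illustration:
import torch
from torch.utils.data import Dataset, DataLoader
class SquaresDataset(Dataset):       # hypothetical example, not part of PyTorch
    def __init__(self, n):
        self.n = n
    def __len__(self):               # used by samplers and the DataLoader defaults
        return self.n
    def __getitem__(self, index):    # map index -> sample
        return torch.tensor(index ** 2)
ds = SquaresDataset(4)
print(list(DataLoader(ds, batch_size=2)))   # [tensor([0, 1]), tensor([4, 9])]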
ConcatDataset
class ConcatDataset(Dataset[T_co]):
r"""Dataset as a concatenation of multiple datasets.
This class is useful to assemble different existing datasets.
Args:
datasets (sequence): List of datasets to be concatenated
"""
datasets: List[Dataset[T_co]]
cumulative_sizes: List[int]
@staticmethod
def cumsum(sequence):
r, s = [], 0
for e in sequence:
l = len(e)
r.append(l + s)
s += l
return r
def __init__(self, datasets: Iterable[Dataset]) -> None:
super().__init__()
self.datasets = list(datasets) # list of Dataset
assert len(self.datasets) > 0, "datasets should not be an empty iterable" # type: ignore[arg-type]
for d in self.datasets:
assert not isinstance(
d, IterableDataset
), "ConcatDataset does not support IterableDataset"
self.cumulative_sizes = self.cumsum(self.datasets) # cumulative sizes of all datasets
def __len__(self): # total size of all datasets
return self.cumulative_sizes[-1]
def __getitem__(self, idx):
if idx < 0:
if -idx > len(self):
raise ValueError(
"absolute value of index should not exceed dataset length"
)
idx = len(self) + idx
# find which dataset `idx` falls into
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
# local index within dataset `dataset_idx`
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
return self.datasets[dataset_idx][sample_idx]
@property
@deprecated(
"`cummulative_sizes` attribute is renamed to `cumulative_sizes`",
category=FutureWarning,
)
def cummulative_sizes(self):
return self.cumulative_sizes
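A quick usage sketch (with toy TensorDatasets) of how cumulative_sizes and bisect_right resolve a global index; Dataset.__add__ builds the same object:
import torch
from torch.utils.data import TensorDataset, ConcatDataset
a = TensorDataset(torch.arange(3))        # 3 samples, global indices 0..2
b = TensorDataset(torch.arange(10, 15))   # 5 samples, global indices 3..7
cat = ConcatDataset([a, b])               # equivalent to `a + b`
print(len(cat))                  # 8
print(cat.cumulative_sizes)      # [3, 8]
print(cat[4])                    # dataset_idx = 1, sample_idx = 4 - 3 = 1 -> (tensor(11),)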
IterableDataset
class IterableDataset(Dataset[T_co], Iterable[T_co]):
r"""An iterable Dataset.
All datasets that represent an iterable of data samples should subclass it.
Such form of datasets is particularly useful when data come from a stream.
All subclasses should overwrite :meth:`__iter__`, which would return an
iterator of samples in this dataset.
When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
iterator. When :attr:`num_workers > 0`, each worker process will have a
different copy of the dataset object, so it is often desired to configure
each copy independently to avoid having duplicate data returned from the
workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
process, returns information about the worker. It can be used in either the
dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
:attr:`worker_init_fn` option to modify each copy's behavior.
Example 1: splitting workload across all workers in :meth:`__iter__`::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
>>> # xdoctest: +SKIP("Fails on MacOS12")
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
... def __init__(self, start, end):
... super().__init__()
... assert end > start, "this example code only works with end > start"
... self.start = start
... self.end = end
...
... def __iter__(self):
... worker_info = torch.utils.data.get_worker_info()
... if worker_info is None: # single-process data loading, return the full iterator
... iter_start = self.start
... iter_end = self.end
... else: # in a worker process
... # split workload
... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
... worker_id = worker_info.id
... iter_start = self.start + worker_id * per_worker
... iter_end = min(iter_start + per_worker, self.end)
... return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)
>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]
>>> # xdoctest: +REQUIRES(POSIX)
>>> # Multi-process loading with two worker processes
>>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
>>> # xdoctest: +IGNORE_WANT("non deterministic")
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]
>>> # With even more workers
>>> # xdoctest: +IGNORE_WANT("non deterministic")
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]
Example 2: splitting workload across all workers using :attr:`worker_init_fn`::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
... def __init__(self, start, end):
... super().__init__()
... assert end > start, "this example code only works with end > start"
... self.start = start
... self.end = end
...
... def __iter__(self):
... return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)
>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]
>>> # Define a `worker_init_fn` that configures each dataset copy differently
>>> def worker_init_fn(worker_id):
... worker_info = torch.utils.data.get_worker_info()
... dataset = worker_info.dataset # the dataset copy in this worker process
... overall_start = dataset.start
... overall_end = dataset.end
... # configure the dataset to only process the split workload
... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
... worker_id = worker_info.id
... dataset.start = overall_start + worker_id * per_worker
... dataset.end = min(dataset.start + per_worker, overall_end)
...
>>> # Multi-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]
>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]
"""
def __add__(self, other: Dataset[T_co]): # overload __add__
return ChainDataset([self, other])
# No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
ChainDataset
class ChainDataset(IterableDataset):
r"""Dataset for chaining multiple :class:`IterableDataset` s.
This class is useful to assemble different existing dataset streams. The
chaining operation is done on-the-fly, so concatenating large-scale
datasets with this class will be efficient.
Args:
datasets (iterable of IterableDataset): datasets to be chained together
"""
def __init__(self, datasets: Iterable[Dataset]) -> None:
super().__init__()
self.datasets = datasets # iterable of IterableDataset
def __iter__(self):
for d in self.datasets:
assert isinstance(
d, IterableDataset
), "ChainDataset only supports IterableDataset"
yield from d # yield items from the current IterableDataset
def __len__(self):
total = 0
for d in self.datasets:
assert isinstance(
d, IterableDataset
), "ChainDataset only supports IterableDataset"
total += len(d) # type: ignore[arg-type]
return total
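A short sketch of chaining two streams, using a hypothetical RangeIterable helper:
from torch.utils.data import IterableDataset, ChainDataset
class RangeIterable(IterableDataset):     # hypothetical helper for illustration
    def __init__(self, start, end):
        self.start, self.end = start, end
    def __iter__(self):
        return iter(range(self.start, self.end))
    def __len__(self):
        return self.end - self.start
chained = ChainDataset([RangeIterable(0, 3), RangeIterable(10, 12)])
print(list(chained))   # [0, 1, 2, 10, 11]
print(len(chained))    # 5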
IterableDatasetShard
class IterableDatasetShard(IterableDataset):
"""
Wraps a PyTorch `IterableDataset` to generate samples for one of the processes only. Instances of this class will
always yield a number of samples that is a round multiple of the actual batch size (depending on the value of
`split_batches`, this is either `batch_size` or `batch_size x num_processes`). Depending on the value of the
`drop_last` argument, it will either stop the iteration at the first batch that would
be too small or complete the last batch with samples from the beginning.
Args:
dataset (`torch.utils.data.dataset.IterableDataset`):
The dataset to split into several shards.
batch_size (`int`, *optional*, defaults to 1):
The size of the batches per shard (if `split_batches=False`) or the size of the batches (if
`split_batches=True`).
drop_last (`bool`, *optional*, defaults to `False`):
Whether or not to drop the last incomplete batch or complete the last batches by using the samples from the
beginning.
num_processes (`int`, *optional*, defaults to 1):
The number of processes running concurrently.
process_index (`int`, *optional*, defaults to 0):
The index of the current process.
split_batches (`bool`, *optional*, defaults to `False`):
Whether the shards should be created by splitting a batch to give a piece of it on each process, or by
yielding different full batches on each process.
On two processes with an iterable dataset yielding `[0, 1, 2, 3, 4, 5, 6, 7]` and a batch size of 4, this will result in:
- the shard on process 0 to yield `[0, 1, 2, 3]` and the shard on process 1 to yield `[4, 5, 6, 7]` if this
argument is set to `False`.
- the shard on process 0 to yield `[0, 1, 4, 5]` and the shard on process 1 to yield `[2, 3, 6, 7]` if
this argument is set to `True`.
"""
def __init__(
self,
dataset: IterableDataset,
batch_size: int = 1,
drop_last: bool = False,
num_processes: int = 1,
process_index: int = 0,
split_batches: bool = False,
):
if split_batches and batch_size > 1 and batch_size % num_processes != 0:
raise ValueError(
f"To use `IterableDatasetShard` in `split_batches` mode, the batch size ({batch_size}) "
f"needs to be a round multiple of the number of processes ({num_processes})."
)
self.dataset = dataset # iterable of single items, not batched items
self.batch_size = batch_size # batch size per shard
self.drop_last = drop_last
self.num_processes = num_processes
self.process_index = process_index
self.split_batches = split_batches
def set_epoch(self, epoch): # record the epoch so seeding can change across epochs
self.epoch = epoch
if hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(epoch)
def __len__(self):
# We will just raise the downstream error if the underlying dataset is not sized
# len(self.dataset) = total items number of dataset
if self.drop_last:
# number of batches = len(self.dataset) // (self.batch_size * self.num_processes)
# number of examples per shard = `number of batches` * self.batch_size
return (len(self.dataset) // (self.batch_size * self.num_processes)) * self.batch_size
else:
return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size
def __iter__(self):
if (
not hasattr(self.dataset, "set_epoch")
and hasattr(self.dataset, "generator")
and isinstance(self.dataset.generator, torch.Generator)
):
self.dataset.generator.manual_seed(self.epoch)
real_batch_size = self.batch_size if self.split_batches else (self.batch_size * self.num_processes) # real batch size pulled from the dataset
process_batch_size = (self.batch_size // self.num_processes) if self.split_batches else self.batch_size # batch size per process
process_slice = range(self.process_index * process_batch_size, (self.process_index + 1) * process_batch_size) # slice of indices for the current process
first_batch = None
current_batch = []
for element in self.dataset: # yield single items from self.dataset
current_batch.append(element)
# Wait to have a full batch before yielding elements.
if len(current_batch) == real_batch_size: # a full real batch has been collected
for i in process_slice:
# yield this process's slice of the batch, item by item
yield current_batch[i]
if first_batch is None:
first_batch = current_batch.copy()
current_batch = []
# Finished if drop_last is True, otherwise complete the last batch with elements from the beginning.
if not self.drop_last and len(current_batch) > 0:
if first_batch is None:
first_batch = current_batch.copy()
while len(current_batch) < real_batch_size:
current_batch += first_batch
for i in process_slice:
yield current_batch[i]
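To see the sharding and the wrap-around padding when drop_last=False, here is a sketch reusing the hypothetical RangeIterable from the ChainDataset example: 10 items with batch_size=2 and 2 processes gives real_batch_size=4, so each shard yields ceil(10 / 4) * 2 = 6 samples:
shard0 = IterableDatasetShard(RangeIterable(0, 10), batch_size=2, num_processes=2, process_index=0)
shard1 = IterableDatasetShard(RangeIterable(0, 10), batch_size=2, num_processes=2, process_index=1)
print(len(shard0))    # 6 == ceil(10 / (2 * 2)) * 2
print(list(shard0))   # [0, 1, 4, 5, 8, 9]  -> last real batch [8, 9] padded with [0, 1, 2, 3]
print(list(shard1))   # [2, 3, 6, 7, 0, 1]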
