Hugging Face Accelerate IterableDatasetShard

Sampler 负责产生样本的 index，（map-style）Dataset 再根据 index 取出对应的样本。

Dataset

  
# Generic type parameters shared by the dataset classes in this module.
T = TypeVar("T")  # plain, invariant type variable
T_co = TypeVar("T_co", covariant=True)  # sample type; declared covariant
T_dict = Dict[str, T_co]  # a sample given as a str-keyed mapping of T_co
T_tuple = Tuple[T_co, ...]  # a sample given as a variable-length tuple of T_co
T_stack = TypeVar("T_stack", T_tuple, T_dict)  # constrained to exactly one of the two aliases above
  
  
class Dataset(Generic[T_co]):
    r"""Abstract base class for map-style datasets.

    A map-style dataset maps keys to data samples. Every subclass must override
    :meth:`__getitem__` so that a sample can be fetched for a given key.
    Overriding :meth:`__len__` is optional, but many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader` expect it to return the size of
    the dataset. Subclasses may also implement :meth:`__getitems__` to speed up
    batched loading: it receives the list of sample indices of one batch and
    returns the list of corresponding samples.

    .. note::
      By default :class:`~torch.utils.data.DataLoader` builds an index sampler
      that yields integral indices. A map-style dataset keyed by non-integral
      indices/keys therefore needs a custom sampler.
    """

    def __getitem__(self, index) -> T_co:
        # Abstract hook: concrete datasets return the sample stored under ``index``.
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    # ``__getitems__(self, indices: List) -> List[T_co]`` is intentionally NOT
    # defined here, not even as a stub: the fetcher check in
    # torch.utils.data._utils.fetch._MapDatasetFetcher would otherwise report a
    # false positive for every subclass.

    def __add__(self, other: "Dataset[T_co]") -> "ConcatDataset[T_co]":
        # ``a + b`` concatenates two map-style datasets, ``a`` first.
        pair = [self, other]
        return ConcatDataset(pair)

    # There is deliberately no default ``__len__``; see
    # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

ConcatDataset

  
class ConcatDataset(Dataset[T_co]):
    r"""Concatenation of several map-style datasets into a single dataset.

    Useful for assembling existing datasets: indices of the result address the
    member datasets back to back, in the order they were given.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """

    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        # Running totals of len(); entry k is the combined size of members 0..k.
        totals = []
        running = 0
        for part in sequence:
            running += len(part)
            totals.append(running)
        return totals

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        # Materialize into a list so members can be indexed and measured repeatedly.
        self.datasets = list(datasets)
        assert self.datasets, "datasets should not be an empty iterable"  # type: ignore[arg-type]
        assert not any(
            isinstance(part, IterableDataset) for part in self.datasets
        ), "ConcatDataset does not support IterableDataset"
        # cumulative_sizes[k] == total number of samples in datasets[0..k].
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        # Total number of samples across all member datasets.
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Negative indices count from the end, like Python sequences.
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx
        # Locate the member dataset whose cumulative range contains idx.
        position = bisect.bisect_right(self.cumulative_sizes, idx)
        # Number of samples held by all members that precede it.
        offset = self.cumulative_sizes[position - 1] if position > 0 else 0
        return self.datasets[position][idx - offset]

    @property
    @deprecated(
        "`cummulative_sizes` attribute is renamed to `cumulative_sizes`",
        category=FutureWarning,
    )
    def cummulative_sizes(self):
        # Deprecated misspelled alias of ``cumulative_sizes``.
        return self.cumulative_sizes

IterableDataset

  
class IterableDataset(Dataset[T_co], Iterable[T_co]):
    r"""An iterable Dataset.

    All datasets that represent an iterable of data samples should subclass it.
    Such form of datasets is particularly useful when data come from a stream.

    All subclasses should overwrite :meth:`__iter__`, which would return an
    iterator of samples in this dataset.

    When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
    item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
    iterator. When :attr:`num_workers > 0`, each worker process will have a
    different copy of the dataset object, so it is often desired to configure
    each copy independently to avoid having duplicate data returned from the
    workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
    process, returns information about the worker. It can be used in either the
    dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
    :attr:`worker_init_fn` option to modify each copy's behavior.

    Example 1: splitting workload across all workers in :meth:`__iter__`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> # xdoctest: +SKIP("Fails on MacOS12")
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         worker_info = torch.utils.data.get_worker_info()
        ...         if worker_info is None:  # single-process data loading, return the full iterator
        ...             iter_start = self.start
        ...             iter_end = self.end
        ...         else:  # in a worker process
        ...             # split workload
        ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
        ...             worker_id = worker_info.id
        ...             iter_start = self.start + worker_id * per_worker
        ...             iter_end = min(iter_start + per_worker, self.end)
        ...         return iter(range(iter_start, iter_end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [tensor([3]), tensor([4]), tensor([5]), tensor([6])]

        >>> # xdoctest: +REQUIRES(POSIX)
        >>> # Multi-process loading with two worker processes
        >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

        >>> # With even more workers
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

    Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [tensor([3]), tensor([4]), tensor([5]), tensor([6])]
        >>>
        >>> # Directly doing multi-process loading yields duplicate data
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [tensor([3]), tensor([3]), tensor([4]), tensor([4]), tensor([5]), tensor([5]), tensor([6]), tensor([6])]

        >>> # Define a `worker_init_fn` that configures each dataset copy differently
        >>> def worker_init_fn(worker_id):
        ...     worker_info = torch.utils.data.get_worker_info()
        ...     dataset = worker_info.dataset  # the dataset copy in this worker process
        ...     overall_start = dataset.start
        ...     overall_end = dataset.end
        ...     # configure the dataset to only process the split workload
        ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
        ...     worker_id = worker_info.id
        ...     dataset.start = overall_start + worker_id * per_worker
        ...     dataset.end = min(dataset.start + per_worker, overall_end)
        ...

        >>> # Multi-process loading with the custom `worker_init_fn`
        >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
        [tensor([3]), tensor([4]), tensor([5]), tensor([6])]
    """

    def __add__(self, other: Dataset[T_co]) -> "ChainDataset":
        # ``a + b`` on iterable datasets chains the two streams: ``a`` first, then ``b``.
        return ChainDataset([self, other])

    # No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]

ChainDataset

  
class ChainDataset(IterableDataset):
    r"""Lazily chains several :class:`IterableDataset` s into one stream.

    Useful for stitching existing dataset streams together. The chaining is
    performed on the fly while iterating, so nothing is copied up front and
    even large-scale datasets can be chained cheaply.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
    """

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        # Kept exactly as passed (no copy). __iter__ and __len__ both loop over
        # it, so a one-shot iterator would only work for the first such call.
        self.datasets = datasets

    @staticmethod
    def _require_iterable(part):
        # Lazy per-member type check, applied right before the member is used.
        assert isinstance(
            part, IterableDataset
        ), "ChainDataset only supports IterableDataset"
        return part

    def __iter__(self):
        # Exhaust each member in order before moving on to the next one.
        for part in self.datasets:
            yield from self._require_iterable(part)

    def __len__(self):
        # Sum of member lengths; propagates whatever len() raises for an unsized member.
        return sum(len(self._require_iterable(part)) for part in self.datasets)  # type: ignore[arg-type]

IterableDatasetShard

  
class IterableDatasetShard(IterableDataset):
    """
    Wraps a PyTorch `IterableDataset` to generate samples for one of the processes only. Instances of this class will
    always yield a number of samples that is a round multiple of the per-process batch size (`batch_size` when
    `split_batches=False`, `batch_size // num_processes` when `split_batches=True`). Depending on the value of
    `drop_last`, it will either stop the iteration at the first batch that would be too small or complete that last
    batch with samples taken from the beginning of the stream.

    Args:
        dataset (`torch.utils.data.dataset.IterableDataset`):
            The iterable dataset to split in several shards. It is iterated item by item (not batch by batch).
        batch_size (`int`, *optional*, defaults to 1):
            The size of the batches per shard (if `split_batches=False`) or the size of the batches (if
            `split_batches=True`).
        drop_last (`bool`, *optional*, defaults to `False`):
            Whether or not to drop the last incomplete batch or complete the last batches by using the samples from the
            beginning.
        num_processes (`int`, *optional*, defaults to 1):
            The number of processes running concurrently.
        process_index (`int`, *optional*, defaults to 0):
            The index of the current process.
        split_batches (`bool`, *optional*, defaults to `False`):
            Whether the shards should be created by splitting a batch to give a piece of it on each process, or by
            yielding different full batches on each process.

            On two processes with `batch_size=4` and an iterable dataset yielding `[0, 1, 2, 3, 4, 5, 6, 7]`, this
            will result in:

            - the shard on process 0 to yield `[0, 1, 2, 3]` and the shard on process 1 to yield `[4, 5, 6, 7]` if
              this argument is set to `False`.
            - the shard on process 0 to yield `[0, 1, 4, 5]` and the shard on process 1 to yield `[2, 3, 6, 7]` if
              this argument is set to `True`.

    Raises:
        ValueError: if `split_batches=True`, `batch_size > 1` and `batch_size` is not a multiple of `num_processes`.
    """

    def __init__(
        self,
        dataset: IterableDataset,
        batch_size: int = 1,
        drop_last: bool = False,
        num_processes: int = 1,
        process_index: int = 0,
        split_batches: bool = False,
    ):
        if split_batches and batch_size > 1 and batch_size % num_processes != 0:
            raise ValueError(
                f"To use `IterableDatasetShard` in `split_batches` mode, the batch size ({batch_size}) "
                f"needs to be a round multiple of the number of processes ({num_processes})."
            )
        self.dataset = dataset  # yields single items, not batches
        self.batch_size = batch_size  # per-process (split_batches=False) or global (split_batches=True)
        self.drop_last = drop_last
        self.num_processes = num_processes
        self.process_index = process_index
        self.split_batches = split_batches
        # Seed used by __iter__ to reseed a `generator` attribute of the dataset; updated by set_epoch.
        # Initialized here so iterating before any set_epoch call does not raise AttributeError.
        self.epoch = 0

    def set_epoch(self, epoch):
        """Record `epoch` and forward it to the wrapped dataset when it has its own `set_epoch`."""
        self.epoch = epoch
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def _batch_sizes(self):
        """Return `(real_batch_size, process_batch_size)`.

        `real_batch_size` is the number of stream items that make up one global batch; `process_batch_size` is how
        many of those items belong to this process.
        """
        if self.split_batches:
            return self.batch_size, self.batch_size // self.num_processes
        return self.batch_size * self.num_processes, self.batch_size

    def __len__(self):
        """Number of items this shard yields; requires `len(self.dataset)`.

        Propagates whatever `len(self.dataset)` raises (e.g. `TypeError` for an unsized dataset).
        """
        real_batch_size, process_batch_size = self._batch_sizes()
        total = len(self.dataset)
        if self.drop_last:
            # Only complete global batches are served.
            num_batches = total // real_batch_size
        else:
            # The trailing partial batch is padded to a full one; integer ceiling avoids float rounding.
            num_batches = -(-total // real_batch_size)
        return num_batches * process_batch_size

    def __iter__(self):
        # When the dataset cannot be told about the epoch through set_epoch but owns a torch.Generator,
        # reseed that generator with the current epoch.
        if (
            not hasattr(self.dataset, "set_epoch")
            and hasattr(self.dataset, "generator")
            and isinstance(self.dataset.generator, torch.Generator)
        ):
            self.dataset.generator.manual_seed(self.epoch)
        real_batch_size, process_batch_size = self._batch_sizes()
        # Positions, within each global batch, of the items owned by this process.
        process_slice = range(self.process_index * process_batch_size, (self.process_index + 1) * process_batch_size)

        first_batch = None
        current_batch = []
        for element in self.dataset:
            current_batch.append(element)
            # Wait to have a full global batch before yielding this process's share of it.
            if len(current_batch) == real_batch_size:
                for i in process_slice:
                    yield current_batch[i]
                if first_batch is None:
                    # Keep the first full batch to pad an incomplete last batch later.
                    first_batch = current_batch.copy()
                current_batch = []

        # Finished if drop_last is True, otherwise complete the last batch with elements from the beginning.
        if not self.drop_last and len(current_batch) > 0:
            if first_batch is None:
                first_batch = current_batch.copy()
            while len(current_batch) < real_batch_size:
                current_batch += first_batch
            for i in process_slice:
                yield current_batch[i]

参考文献

0
0
0
0
评论
未登录
暂无评论