diff options
Diffstat (limited to 'code/sunlab/common/data/dataset_iterator.py')
-rw-r--r-- | code/sunlab/common/data/dataset_iterator.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/code/sunlab/common/data/dataset_iterator.py b/code/sunlab/common/data/dataset_iterator.py new file mode 100644 index 0000000..7c91caa --- /dev/null +++ b/code/sunlab/common/data/dataset_iterator.py @@ -0,0 +1,34 @@ +class DatasetIterator: + """# Dataset Iterator + + Creates an iterator object on a dataset and labels""" + + def __init__(self, dataset, labels=None, batch_size=None): + """# Initialize the iterator with the dataset and labels + + - batch_size: How many to include in the iteration""" + self.dataset = dataset + self.labels = labels + self.current = 0 + self.batch_size = ( + batch_size if batch_size is not None else self.dataset.shape[0] + ) + + def __iter__(self): + """# Iterator Function""" + return self + + def __next__(self): + """# Next Iteration + + Slice the dataset and labels to provide""" + self.cur = self.current + self.current += 1 + if self.cur * self.batch_size < self.dataset.shape[0]: + iterator_slice = slice( + self.cur * self.batch_size, (self.cur + 1) * self.batch_size + ) + if self.labels is None: + return self.dataset[iterator_slice, ...] + return self.dataset[iterator_slice, ...], self.labels[iterator_slice, ...] + raise StopIteration |