
# Copyright 2017,2018,2019,2020,2021 Sony Corporation.
# Copyright 2021 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

'''
Data Iterator is a module for getting data from a user-defined data source.

The detailed design document is :doc:`/doc/designs/data_iterator`.

'''

import atexit
import threading
import queue

import numpy
import six
from nnabla.logger import logger

from .data_source import DataSourceWithFileCache
from .data_source import DataSourceWithMemoryCache
from .data_source import SlicedDataSource
from .data_source_implements import CacheDataSource
from .data_source_implements import ConcatDataSource
from .data_source_implements import CsvDataSource
from .data_source_implements import SimpleDataSource


class DataIterator(object):
    '''DataIterator
    Collects data from `data_source` and yields batches of data.

    Args:
        data_source (:py:class:`DataSource <nnabla.utils.data_source.DataSource>`):
            Instance of DataSource class which provides data for this class.
        batch_size (int): Size of data unit.
        rng (None or :obj:`numpy.random.RandomState`): Numpy random number
            generator.
        use_thread (bool): If ``use_thread`` is set to True, the iterator
            uses another thread to fetch data. If set to False, the iterator
            uses the current thread to fetch data.
        epoch_begin_callbacks (list of functions): An item is a function
            which takes an epoch index as an argument. These are called at
            the beginning of an epoch.
        epoch_end_callbacks (list of functions): An item is a function
            which takes an epoch index as an argument. These are called at
            the end of an epoch.
        stop_exhausted (bool): If ``stop_exhausted`` is set to False, the
            iterator is reset so that iteration can continue. If set to
            True, the iterator raises StopIteration to stop the loop.
    '''

    def __init__(self, data_source, batch_size, rng=None, use_thread=True,
                 epoch_begin_callbacks=[], epoch_end_callbacks=[],
                 stop_exhausted=False):
        logger.info('Using DataIterator')
        if rng is None:
            rng = numpy.random.RandomState(313)

        self._stop_exhausted = stop_exhausted
        self._rng = rng
        self._shape = None       # Only used with padding
        self._data_position = 0  # Only used with padding

        self._data_source = data_source
        # Record whether shuffle was enabled when the iterator was created.
        self._shuffle = self._data_source.shuffle

        self._variables = data_source.variables
        self._num_of_variables = len(data_source.variables)
        self._batch_size = batch_size
        self._epoch_end_callbacks = list(epoch_end_callbacks)
        self._epoch_begin_callbacks = list(epoch_begin_callbacks)

        self._size = data_source.size
        self._n_reset = 0

        self._queue = queue.Queue()
        self._reset()
        self._current_epoch = 0

        self._use_thread = use_thread
        if self._use_thread:
            self._next_thread = threading.Thread(target=self._next)
            self._next_thread.start()

        self._closed = False
        atexit.register(self.close)

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()

    def close(self):
        if not self._closed:
            if six.PY3:
                atexit.unregister(self.close)
            if self._use_thread:
                self._next_thread.join()
            self._data_source.close()
            self._closed = True

    @property
    def epoch(self):
        '''epoch
        The number of times :py:meth:`position` returns to zero.

        Returns:
            int: epoch
        '''
        return self._current_epoch

    @property
    def position(self):
        '''position
        Data position in the current epoch.

        Returns:
            int: Data position
        '''
        return self._data_source.position

    @property
    def size(self):
        '''size
        Data size that DataIterator will generate.
        This is the largest integer multiple of batch_size not exceeding
        :py:meth:`self._data_source.size`.

        Returns:
            int: Data size
        '''
        return self._size

    @property
    def variables(self):
        '''variables
        Variable names of the data.

        Returns:
            tuple: tuple of Variable names
        '''
        return self._variables

    @property
    def batch_size(self):
        '''batch_size
        Number of training samples that :py:meth:`next()` returns.

        Returns:
            int: Number of training samples.
        '''
        return self._batch_size

    def _reset(self):
        self._data_source.reset()

    def _next(self):
        # Assemble one mini-batch and push it onto the queue, together with
        # the number of data-source resets that happened while assembling it.
        n_reset = 0
        data = [[] for x in self._variables]
        if self._data_source.position + self._batch_size > self._size:
            if self._stop_exhausted:
                self._queue.put((None, 0))
                return
        for b in range(self._batch_size):
            d = self._data_source.next()
            if d is None:
                self._queue.put((None, 0))
                return
            if self._data_source.position >= self._size and not self._stop_exhausted:
                n_reset += 1
                self._reset()
            for i, v in enumerate(self._variables):
                data[i].append(d[i])
        if n_reset != 0:
            self._data_source.apply_order()
        self._queue.put((tuple([numpy.array(x) for x in data]), n_reset))
    def next(self):
        '''next
        Generates a tuple of data.

        For example, if ``self._variables == ('x', 'y')``, this method
        returns ``([[X] * batch_size], [[Y] * batch_size])``.

        Returns:
            tuple: tuple of data for mini-batch in numpy.ndarray.
        '''
        if not self._use_thread:
            self._next()
        data, n_reset = self._queue.get()
        self._queue.task_done()
        if self._use_thread:
            self._next_thread.join()
        if data is None:
            if self._stop_exhausted and self._data_source.position + self._batch_size >= self._size:
                raise StopIteration
            if self._use_thread:
                # The fetch thread produced nothing; restart it and retry.
                logger.log(99, 'next() got None retrying...')
                self._next_thread = threading.Thread(target=self._next)
                self._next_thread.start()
                data, n_reset = self._queue.get()
                self._queue.task_done()
                self._next_thread.join()
        if self._use_thread:
            # Start prefetching the next batch in the background.
            self._next_thread = threading.Thread(target=self._next)
            self._next_thread.start()
        for _ in range(n_reset):
            if self._current_epoch >= 0:
                self._callback_epoch_end()
            self._current_epoch += 1
            self._callback_epoch_begin()
        return data
    def slice(self, rng, num_of_slices=None, slice_pos=None,
              slice_start=None, slice_end=None,
              cache_dir=None, use_cache=False, drop_last=False):
        '''
        Slices the data iterator so that the newly generated data iterator
        has access to a limited portion of the original data.

        Args:
            rng (numpy.random.RandomState): Random generator for Initializer.
            num_of_slices (int): Total number of slices to be made. Must be
                used together with `slice_pos`.
            slice_pos (int): Position of the slice to be assigned to the new
                data iterator. Must be used together with `num_of_slices`.
            slice_start (int): Starting position of the range to be sliced
                into the new data iterator. Must be used together with
                `slice_end`.
            slice_end (int): End position of the range to be sliced into the
                new data iterator. Must be used together with `slice_start`.
            cache_dir (str): Directory to save cache files. If cache_dir is
                None and use_cache is True, a memory cache is used.
            use_cache (bool): Whether to use a cache for data_source.
            drop_last (bool): If True, leftover samples that cannot be
                divided evenly across slices are dropped. If False, samples
                are duplicated so that every slice has the same size.

        Example:

        .. code-block:: python

            from nnabla.utils.data_iterator import data_iterator_simple
            import numpy as np

            def load_func1(index):
                d = np.ones((2, 2)) * index
                return d

            di = data_iterator_simple(load_func1, 1000, batch_size=3)

            di_s1 = di.slice(None, num_of_slices=10, slice_pos=0)
            di_s2 = di.slice(None, num_of_slices=10, slice_pos=1)

            di_s3 = di.slice(None, slice_start=100, slice_end=200)
            di_s4 = di.slice(None, slice_start=300, slice_end=400)

        '''

        def get_slice_start_end(size, n_slices, rank):
            """
            Evenly assigns any remainder across the slice blocks using
            integer arithmetic.
            """
            if drop_last:
                size_rank = size // n_slices
                slice_start = size_rank * rank
                slice_end = slice_start + size_rank
            else:
                size_rank = (size + n_slices - 1) // n_slices
                slice_start = (size_rank * rank) % size
                slice_end = slice_start + size_rank
                if slice_end > size:
                    # Shift the last slice back so it duplicates samples
                    # instead of running past the end of the data.
                    offset = slice_end - size
                    slice_end -= offset
                    slice_start -= offset
            return slice_start, slice_end

        if num_of_slices is not None and slice_pos is not None and slice_start is None and slice_end is None:
            slice_start, slice_end = get_slice_start_end(
                self._size, num_of_slices, slice_pos)
        elif num_of_slices is None and slice_pos is None and slice_start is not None and slice_end is not None:
            pass
        else:
            logger.critical(
                'You must specify position (num_of_slices and slice_pos) or range (slice_start and slice_end).')
            return None

        if cache_dir is None:
            # Inherit the cache directory from a wrapped data source, if any.
            ds = self._data_source
            while '_data_source' in dir(ds):
                if '_cache_dir' in dir(ds):
                    cache_dir = ds._cache_dir
                ds = ds._data_source

        if use_cache:
            if cache_dir is None:
                return DataIterator(
                    DataSourceWithMemoryCache(
                        SlicedDataSource(
                            self._data_source,
                            self._data_source.shuffle,
                            slice_start=slice_start,
                            slice_end=slice_end),
                        shuffle=self._shuffle,
                        rng=rng),
                    self._batch_size)
            else:
                return DataIterator(
                    DataSourceWithMemoryCache(
                        DataSourceWithFileCache(
                            SlicedDataSource(
                                self._data_source,
                                self._data_source.shuffle,
                                slice_start=slice_start,
                                slice_end=slice_end),
                            cache_dir=cache_dir,
                            cache_file_name_prefix='cache_sliced_{:08d}_{:08d}'.format(
                                slice_start,
                                slice_end),
                            shuffle=self._shuffle,
                            rng=rng),
                        shuffle=self._shuffle,
                        rng=rng),
                    self._batch_size)
        else:
            return DataIterator(
                SlicedDataSource(
                    self._data_source,
                    self._data_source.shuffle,
                    slice_start=slice_start,
                    slice_end=slice_end),
                self._batch_size)
    def _callback_epoch_end(self):
        for callback in self._epoch_end_callbacks:
            callback(self.epoch)

    def _callback_epoch_begin(self):
        for callback in self._epoch_begin_callbacks:
            callback(self.epoch)
    def register_epoch_end_callback(self, callback):
        """Register an epoch end callback.

        Args:
            callback (function): A function which takes an epoch index as an
                argument.
        """
        self._epoch_end_callbacks.append(callback)
    def register_epoch_begin_callback(self, callback):
        """Register an epoch begin callback.

        Args:
            callback (function): A function which takes an epoch index as an
                argument.
        """
        self._epoch_begin_callbacks.append(callback)
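
# Illustrative usage sketch (not part of the original module): a minimal
# loop driving DataIterator directly with the SimpleDataSource imported
# above. The function name `_demo_data_iterator` is hypothetical and the
# function is never called at import time.
def _demo_data_iterator():
    def load_func(i):
        # Every index must return arrays with consistent shapes.
        return numpy.full((2, 2), i, dtype='float32'), numpy.array(i % 10)

    # DataIterator is also a context manager; close() is called on exit.
    with DataIterator(SimpleDataSource(load_func, 100, shuffle=True),
                      batch_size=10) as di:
        for _ in range(di.size // di.batch_size):
            x, y = di.next()  # arrays ordered as di.variables
            assert x.shape == (10, 2, 2)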
def data_iterator(data_source,
                  batch_size,
                  rng=None,
                  use_thread=True,
                  with_memory_cache=True,
                  with_file_cache=False,
                  cache_dir=None,
                  epoch_begin_callbacks=[],
                  epoch_end_callbacks=[],
                  stop_exhausted=False):
    '''data_iterator
    Helper method to use :py:class:`DataSource <nnabla.utils.data_source.DataSource>`.

    You can use :py:class:`DataIterator <nnabla.utils.data_iterator.DataIterator>`
    with your own :py:class:`DataSource <nnabla.utils.data_source.DataSource>`
    for easy implementation of data sources.

    For example,

    .. code-block:: python

        ds = YourOwnImplementationOfDataSource()
        batch = data_iterator(ds, batch_size)

    Args:
        data_source (:py:class:`DataSource <nnabla.utils.data_source.DataSource>`):
            Instance of DataSource class which provides data.
        batch_size (int): Batch size.
        rng (None or :obj:`numpy.random.RandomState`): Numpy random number
            generator.
        use_thread (bool): If ``use_thread`` is set to True, the iterator
            uses another thread to fetch data. If set to False, the iterator
            uses the current thread to fetch data.
        with_memory_cache (bool): If ``True``, use
            :py:class:`.data_source.DataSourceWithMemoryCache` to wrap
            ``data_source``. It is a good idea to set this to True unless
            data_source provides on-memory data. Default value is True.
        with_file_cache (bool): If ``True``, use
            :py:class:`.data_source.DataSourceWithFileCache` to wrap
            ``data_source``. If ``data_source`` is slow, enabling this
            option is a good idea. Default value is False.
        cache_dir (str): Location of the file cache. If this value is None,
            :py:class:`.data_source.DataSourceWithFileCache` creates file
            caches implicitly in a temporary directory and erases them all
            when the data_iterator is finished. Otherwise,
            :py:class:`.data_source.DataSourceWithFileCache` keeps the
            created cache. Default is None.
        epoch_begin_callbacks (list of functions): An item is a function
            which takes an epoch index as an argument. These are called at
            the beginning of an epoch.
        epoch_end_callbacks (list of functions): An item is a function
            which takes an epoch index as an argument. These are called at
            the end of an epoch.
        stop_exhausted (bool): If ``stop_exhausted`` is set to False, the
            iterator is reset so that iteration can continue. If set to
            True, the iterator raises StopIteration to stop the loop.

    Returns:
        :py:class:`DataIterator <nnabla.utils.data_iterator.DataIterator>`:
            Instance of DataIterator.
    '''
    if with_file_cache:
        ds = DataSourceWithFileCache(data_source=data_source,
                                     cache_dir=cache_dir,
                                     shuffle=data_source.shuffle,
                                     rng=rng)
        if with_memory_cache:
            ds = DataSourceWithMemoryCache(ds,
                                           shuffle=ds.shuffle,
                                           rng=rng)
        return DataIterator(ds, batch_size,
                            use_thread=use_thread,
                            epoch_begin_callbacks=epoch_begin_callbacks,
                            epoch_end_callbacks=epoch_end_callbacks,
                            stop_exhausted=stop_exhausted)
    else:
        if with_memory_cache:
            data_source = DataSourceWithMemoryCache(data_source,
                                                    shuffle=data_source.shuffle,
                                                    rng=rng)
        return DataIterator(data_source, batch_size,
                            use_thread=use_thread,
                            epoch_begin_callbacks=epoch_begin_callbacks,
                            epoch_end_callbacks=epoch_end_callbacks,
                            stop_exhausted=stop_exhausted)
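
# Illustrative sketch (not part of the original module) of the custom
# DataSource pattern that data_iterator() expects. It assumes the
# DataSource base-class contract from nnabla.utils.data_source
# (`_get_data`, `_size`, `_variables`, `reset`); the `_demo_*` and
# `MyInMemorySource` names are hypothetical and exist only for illustration.
def _demo_custom_data_source(batch_size=16):
    from nnabla.utils.data_source import DataSource

    class MyInMemorySource(DataSource):
        '''Serves (x, y) pairs held in memory, optionally shuffled.'''

        def __init__(self, xs, ys, shuffle=False, rng=None):
            super(MyInMemorySource, self).__init__(shuffle=shuffle, rng=rng)
            self._xs, self._ys = xs, ys
            self._size = len(xs)
            self._variables = ('x', 'y')
            self.reset()

        def reset(self):
            # Draw a new visiting order at every epoch when shuffling.
            if self._shuffle:
                self._order = list(self._rng.permutation(self._size))
            else:
                self._order = list(range(self._size))
            super(MyInMemorySource, self).reset()

        def _get_data(self, position):
            i = self._order[position]
            return (self._xs[i], self._ys[i])

    xs = numpy.random.randn(1000, 8).astype('float32')
    ys = numpy.random.randint(0, 10, size=(1000, 1))
    return data_iterator(MyInMemorySource(xs, ys, shuffle=True), batch_size)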
def data_iterator_simple(load_func,
                         num_examples,
                         batch_size,
                         shuffle=False,
                         rng=None,
                         use_thread=True,
                         with_memory_cache=False,
                         with_file_cache=False,
                         cache_dir=None,
                         epoch_begin_callbacks=[],
                         epoch_end_callbacks=[],
                         stop_exhausted=False):
    """A generator that yields minibatch data as a tuple, as defined in
    ``load_func``. It can yield minibatches indefinitely on request, queried
    from the provided data.

    Args:
        load_func (function): Takes a single argument `i`, an index of an
            example in your dataset to be loaded, and returns a tuple of
            data. Every call by any index `i` must return a tuple of arrays
            with the same shape.
        num_examples (int): Number of examples in your dataset. A random
            sequence of indexes is generated according to this number.
        batch_size (int): Size of data unit.
        shuffle (bool): Indicates whether the dataset is shuffled or not.
            Default value is False.
        rng (None or :obj:`numpy.random.RandomState`): Numpy random number
            generator.
        use_thread (bool): If ``use_thread`` is set to True, the iterator
            uses another thread to fetch data. If set to False, the iterator
            uses the current thread to fetch data.
        with_memory_cache (bool): If ``True``, use
            :py:class:`.data_source.DataSourceWithMemoryCache` to wrap
            ``data_source``. It is a good idea to set this to True unless
            data_source provides on-memory data. Default value is False.
        with_file_cache (bool): If ``True``, use
            :py:class:`.data_source.DataSourceWithFileCache` to wrap
            ``data_source``. If ``data_source`` is slow, enabling this
            option is a good idea. Default value is False.
        cache_dir (str): Location of the file cache. If this value is None,
            :py:class:`.data_source.DataSourceWithFileCache` creates file
            caches implicitly in a temporary directory and erases them all
            when the data_iterator is finished. Otherwise,
            :py:class:`.data_source.DataSourceWithFileCache` keeps the
            created cache. Default is None.
        epoch_begin_callbacks (list of functions): An item is a function
            which takes an epoch index as an argument. These are called at
            the beginning of an epoch.
        epoch_end_callbacks (list of functions): An item is a function
            which takes an epoch index as an argument. These are called at
            the end of an epoch.
        stop_exhausted (bool): If ``stop_exhausted`` is set to False, the
            iterator is reset so that iteration can continue. If set to
            True, the iterator raises StopIteration to stop the loop.

    Returns:
        :py:class:`DataIterator <nnabla.utils.data_iterator.DataIterator>`:
            Instance of DataIterator.

    Here is an example of `load_func` which returns an image and a label of
    a classification dataset.

    .. code-block:: python

        import numpy as np
        from nnabla.utils.image_utils import imread
        image_paths = load_image_paths()
        labels = load_labels()

        def my_load_func(i):
            '''
            Returns:
                image: c x h x w array
                label: 0-shape array
            '''
            img = imread(image_paths[i]).astype('float32')
            return np.rollaxis(img, 2), np.array(labels[i])

    """
    return data_iterator(SimpleDataSource(load_func,
                                          num_examples,
                                          shuffle=shuffle,
                                          rng=rng),
                         use_thread=use_thread,
                         batch_size=batch_size,
                         with_memory_cache=with_memory_cache,
                         with_file_cache=with_file_cache,
                         cache_dir=cache_dir,
                         epoch_begin_callbacks=epoch_begin_callbacks,
                         epoch_end_callbacks=epoch_end_callbacks,
                         stop_exhausted=stop_exhausted)
def data_iterator_csv_dataset(uri,
                              batch_size,
                              shuffle=False,
                              rng=None,
                              use_thread=True,
                              normalize=True,
                              with_memory_cache=True,
                              with_file_cache=True,
                              cache_dir=None,
                              epoch_begin_callbacks=[],
                              epoch_end_callbacks=[],
                              stop_exhausted=False):
    '''data_iterator_csv_dataset
    Gets data directly from a dataset provided as a CSV file.

    You can read files located on the local file system, http(s) servers or
    Amazon AWS S3 storage.

    For example,

    .. code-block:: python

        batch = data_iterator_csv_dataset('CSV_FILE.csv', batch_size, shuffle=True)

    Args:
        uri (str): Location of the dataset CSV file.
        batch_size (int): Size of data unit.
        shuffle (bool): Indicates whether the dataset is shuffled or not.
            Default value is False.
        rng (None or :obj:`numpy.random.RandomState`): Numpy random number
            generator.
        use_thread (bool): If ``use_thread`` is set to True, the iterator
            uses another thread to fetch data. If set to False, the iterator
            uses the current thread to fetch data.
        normalize (bool): If True, each sample in the data is normalized by
            dividing it by 255. Default is True.
        with_memory_cache (bool): If ``True``, use
            :py:class:`.data_source.DataSourceWithMemoryCache` to wrap
            ``data_source``. It is a good idea to set this to True unless
            data_source provides on-memory data. Default value is True.
        with_file_cache (bool): If ``True``, use
            :py:class:`.data_source.DataSourceWithFileCache` to wrap
            ``data_source``. If ``data_source`` is slow, enabling this
            option is a good idea. Default value is True.
        cache_dir (str): Location of the file cache. If this value is None,
            :py:class:`.data_source.DataSourceWithFileCache` creates file
            caches implicitly in a temporary directory and erases them all
            when the data_iterator is finished. Otherwise,
            :py:class:`.data_source.DataSourceWithFileCache` keeps the
            created cache. Default is None.
        epoch_begin_callbacks (list of functions): An item is a function
            which takes an epoch index as an argument. These are called at
            the beginning of an epoch.
        epoch_end_callbacks (list of functions): An item is a function
            which takes an epoch index as an argument. These are called at
            the end of an epoch.
        stop_exhausted (bool): If ``stop_exhausted`` is set to False, the
            iterator is reset so that iteration can continue. If set to
            True, the iterator raises StopIteration to stop the loop.

    Returns:
        :py:class:`DataIterator <nnabla.utils.data_iterator.DataIterator>`:
            Instance of DataIterator
    '''
    ds = CsvDataSource(uri,
                       shuffle=shuffle,
                       rng=rng,
                       normalize=normalize)

    return data_iterator(ds,
                         use_thread=use_thread,
                         batch_size=batch_size,
                         with_memory_cache=with_memory_cache,
                         with_file_cache=with_file_cache,
                         cache_dir=cache_dir,
                         epoch_begin_callbacks=epoch_begin_callbacks,
                         epoch_end_callbacks=epoch_end_callbacks,
                         stop_exhausted=stop_exhausted)
def data_iterator_cache(uri,
                        batch_size,
                        shuffle=False,
                        rng=None,
                        use_thread=True,
                        normalize=True,
                        with_memory_cache=True,
                        epoch_begin_callbacks=[],
                        epoch_end_callbacks=[],
                        stop_exhausted=False):
    '''data_iterator_cache
    Gets data from a cache directory.

    Cache files are read from the local file system.

    For example,

    .. code-block:: python

        batch = data_iterator_cache('CACHE_DIR', batch_size, shuffle=True)

    Args:
        uri (str): Location of the directory with cache files.
        batch_size (int): Size of data unit.
        shuffle (bool): Indicates whether the dataset is shuffled or not.
            Default value is False.
        rng (None or :obj:`numpy.random.RandomState`): Numpy random number
            generator.
        use_thread (bool): If ``use_thread`` is set to True, the iterator
            uses another thread to fetch data. If set to False, the iterator
            uses the current thread to fetch data.
        normalize (bool): If True, each sample in the data is normalized by
            dividing it by 255. Default is True.
        with_memory_cache (bool): If ``True``, use
            :py:class:`.data_source.DataSourceWithMemoryCache` to wrap
            ``data_source``. It is a good idea to set this to True unless
            data_source provides on-memory data. Default value is True.
        epoch_begin_callbacks (list of functions): An item is a function
            which takes an epoch index as an argument. These are called at
            the beginning of an epoch.
        epoch_end_callbacks (list of functions): An item is a function
            which takes an epoch index as an argument. These are called at
            the end of an epoch.
        stop_exhausted (bool): If ``stop_exhausted`` is set to False, the
            iterator is reset so that iteration can continue. If set to
            True, the iterator raises StopIteration to stop the loop.

    Returns:
        :py:class:`DataIterator <nnabla.utils.data_iterator.DataIterator>`:
            Instance of DataIterator
    '''
    ds = CacheDataSource(uri,
                         shuffle=shuffle,
                         rng=rng,
                         normalize=normalize)

    return data_iterator(ds,
                         use_thread=use_thread,
                         batch_size=batch_size,
                         with_memory_cache=with_memory_cache,
                         epoch_begin_callbacks=epoch_begin_callbacks,
                         epoch_end_callbacks=epoch_end_callbacks,
                         stop_exhausted=stop_exhausted)
def data_iterator_concat_datasets(data_source_list,
                                  batch_size,
                                  shuffle=False,
                                  rng=None,
                                  use_thread=True,
                                  with_memory_cache=True,
                                  with_file_cache=False,
                                  cache_dir=None,
                                  epoch_begin_callbacks=[],
                                  epoch_end_callbacks=[],
                                  stop_exhausted=False):
    '''data_iterator_concat_datasets
    Gets data from multiple datasets.

    For example,

    .. code-block:: python

        batch = data_iterator_concat_datasets([DataSource0, DataSource1, ...], batch_size)

    Args:
        data_source_list (list of DataSource): List of datasets.
        batch_size (int): Size of data unit.
        shuffle (bool): Indicates whether the dataset is shuffled or not.
            Default value is False.
        rng (None or :obj:`numpy.random.RandomState`): Numpy random number
            generator.
        use_thread (bool): If ``use_thread`` is set to True, the iterator
            uses another thread to fetch data. If set to False, the iterator
            uses the current thread to fetch data.
        with_memory_cache (bool): If ``True``, use
            :py:class:`.data_source.DataSourceWithMemoryCache` to wrap
            ``data_source``. It is a good idea to set this to True unless
            data_source provides on-memory data. Default value is True.
        with_file_cache (bool): If ``True``, use
            :py:class:`.data_source.DataSourceWithFileCache` to wrap
            ``data_source``. If ``data_source`` is slow, enabling this
            option is a good idea. Default value is False.
        cache_dir (str): Location of the file cache. If this value is None,
            :py:class:`.data_source.DataSourceWithFileCache` creates file
            caches implicitly in a temporary directory and erases them all
            when the data_iterator is finished. Otherwise,
            :py:class:`.data_source.DataSourceWithFileCache` keeps the
            created cache. Default is None.
        epoch_begin_callbacks (list of functions): An item is a function
            which takes an epoch index as an argument. These are called at
            the beginning of an epoch.
        epoch_end_callbacks (list of functions): An item is a function
            which takes an epoch index as an argument. These are called at
            the end of an epoch.
        stop_exhausted (bool): If ``stop_exhausted`` is set to False, the
            iterator is reset so that iteration can continue. If set to
            True, the iterator raises StopIteration to stop the loop.

    Returns:
        :py:class:`DataIterator <nnabla.utils.data_iterator.DataIterator>`:
            Instance of DataIterator
    '''
    ds = ConcatDataSource(data_source_list,
                          shuffle=shuffle,
                          rng=rng)

    return data_iterator(ds,
                         use_thread=use_thread,
                         batch_size=batch_size,
                         with_memory_cache=with_memory_cache,
                         with_file_cache=with_file_cache,
                         epoch_begin_callbacks=epoch_begin_callbacks,
                         epoch_end_callbacks=epoch_end_callbacks,
                         stop_exhausted=stop_exhausted)
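
# Illustrative sketch (not part of the original module): splitting one
# dataset across parallel workers with DataIterator.slice(). The worker
# count below is a hypothetical placeholder; in a real multi-process setup
# it would come from your communicator.
def _demo_sliced_iterators():
    def load_func(i):
        return numpy.full((4,), i, dtype='float32')

    di = data_iterator_simple(load_func, 1000, batch_size=10)

    n_workers = 4  # hypothetical
    workers = [di.slice(None, num_of_slices=n_workers, slice_pos=rank)
               for rank in range(n_workers)]
    # With drop_last=False (the default), any remainder is duplicated so
    # that every slice has the same length; here 1000 / 4 divides evenly,
    # so each sliced iterator sees a disjoint 250-sample portion.
    for w in workers:
        assert w.size == 250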