# Source code for nnabla.utils.data_source

# Copyright (c) 2017 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# TODO temporary work around to suppress FutureWarning message.
import warnings
warnings.simplefilter('ignore', category=FutureWarning)
import h5py

from collections import OrderedDict
from contextlib import closing
from contextlib import closing
from multiprocessing import Queue
from multiprocessing.pool import ThreadPool
from shutil import rmtree
import abc
import atexit

# TODO temporary work around to suppress FutureWarning message.
import warnings
warnings.simplefilter('ignore', category=FutureWarning)
import h5py

import csv
import numpy
import os
import six
import tempfile
import threading

from nnabla.config import nnabla_config
from nnabla.logger import logger
from nnabla.utils.progress import progress
from nnabla.utils.communicator_util import single_or_rankzero
from .data_source_loader import FileReader


class DataSource(object):
    '''
    This class contains various properties and methods for the data source,
    which are utilized by :py:class:`DataIterator`.

    Args:
        shuffle (bool): Indicates whether the dataset is shuffled or not.
        rng (None or :obj:`numpy.random.RandomState`): Numpy random number
            generator. When None, a RandomState seeded with 313 is used.
    '''

    __metaclass__ = abc.ABCMeta

    @abc.abstractmethod
    def _get_data(self, position):
        '''Return the data at ``position``; must be provided by subclasses.'''

    def __init__(self, shuffle=False, rng=None):
        '''
        Init method for DataSource
        '''
        logger.info('DataSource with shuffle({})'.format(shuffle))
        # Fall back to a fixed seed so runs are reproducible by default.
        self._rng = numpy.random.RandomState(313) if rng is None else rng

        self._variables = None
        self._generation = -1
        self._shuffle = shuffle
        self._position = 0
        self._size = 0
        self._closed = False

        self._order = None
        self._original_order = None
        self._original_source_uri = None

        # Make sure close() runs even if the caller never closes explicitly.
        atexit.register(self.close)

    def __next__(self):
        return self.next()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()

    def close(self):
        '''Mark the source as closed and drop the atexit hook; repeated calls are no-ops.'''
        if self._closed:
            return
        if six.PY3:
            # atexit.unregister only exists on Python 3.
            atexit.unregister(self.close)
        self._closed = True

    @property
    def variables(self):
        '''variables
        Variable names of the data.

        Returns:
           tuple: tuple of Variable names
        '''
        return self._variables

    def next(self):
        '''Return the data at the current position, then advance the position by one.'''
        # _get_data must run before the increment: some subclasses'
        # _get_data overwrite self._position with the requested position.
        data = self._get_data(self._position)
        self._position += 1
        return data

    @property
    def position(self):
        '''position
        Data position in current epoch.

        Returns:
            int: Data position
        '''
        return self._position

    @property
    def size(self):
        '''Size of the data source (the ``_size`` value set by the concrete source).

        Returns:
            int: Data size
        '''
        return self._size

    @property
    def shuffle(self):
        '''
        Whether dataset is shuffled or not.

        Returns:
           bool: whether dataset is shuffled.
        '''
        return self._shuffle

    @shuffle.setter
    def shuffle(self, value):
        self._shuffle = value

    @abc.abstractmethod
    def reset(self):
        '''Rewind to the start of the data; subclasses extend this.'''
        self._position = 0
class DataSourceWithFileCacheError(Exception):
    '''Error raised by :py:class:`DataSourceWithFileCache` while building its file cache.'''
class DataSourceWithFileCache(DataSource):
    '''
    This class contains properties and methods for data source that can be
    read from cache files, which are utilized by data iterator.

    Args:
        data_source (:py:class:`DataSource <nnabla.utils.data_source.DataSource>`):
            Instance of DataSource class which provides data.
        cache_dir (str):
            Location of file_cache.
            If this value is None, :py:class:`.data_source.DataSourceWithFileCache`
            creates file caches implicitly on temporary directory and erases them all
            when data_iterator is finished.
            Otherwise, :py:class:`.data_source.DataSourceWithFileCache` keeps created cache.
            Default is None.
        cache_file_name_prefix (str):
            Beginning of the filenames of cache files.
            Default is 'cache'.
        shuffle (bool):
            Indicates whether the dataset is shuffled or not.
        rng (None or :obj:`numpy.random.RandomState`): Numpy random number
            generator.
    '''

    def _save_cache_to_file(self):
        '''
        Store the buffered positions (``self._cache_positions``) into one cache file.

        The file format follows the ``cache_file_format`` config value
        (``.h5`` -> HDF5 datasets, otherwise consecutive ``numpy.save`` arrays).
        The file name is ``PREFIX_START_END<format>`` with 8-digit positions.

        Raises:
            DataSourceWithFileCacheError: if no cache directory is set, or if the
                underlying data source keeps returning None for a position.
        '''
        if self._cache_dir is None:
            raise DataSourceWithFileCacheError(
                'Use this class with "with statement" if you don\'t specify cache dir.')

        def get_data(pos):
            # Fetch one sample; the source may transiently return None, so
            # retry up to 10 attempts before giving up.
            retry = 1
            while True:
                if retry > 10:
                    logger.log(
                        99, '_get_current_data() retry count over give up.')
                    # A bare `raise` here had no active exception and produced
                    # an opaque RuntimeError; raise a descriptive error instead.
                    raise DataSourceWithFileCacheError(
                        '_get_data({}) returned None after 10 attempts.'.format(pos))
                d = self._data_source._get_data(pos)
                if d is not None:
                    break
                logger.log(99, '_get_data() fails. retrying count {}/10.'.format(
                    retry))
                retry += 1
            return pos, d

        # pool.map blocks until every worker is done and propagates worker
        # exceptions, so its result list replaces the former Queue drain loop.
        with closing(ThreadPool(processes=self._num_of_threads)) as pool:
            results = pool.map(get_data, self._cache_positions)
        cache_data = OrderedDict(results)

        start_position = self.position - len(cache_data) + 1
        end_position = self.position
        cache_filename = os.path.join(
            self._cache_dir, '{}_{:08d}_{:08d}{}'.format(self._cache_file_name_prefix,
                                                         start_position,
                                                         end_position,
                                                         self._cache_file_format))

        # Transpose samples into one list of arrays per variable, in position order.
        data = OrderedDict([(n, []) for n in self._data_source.variables])
        for pos in sorted(cache_data):
            cd = cache_data[pos]
            for i, n in enumerate(self._data_source.variables):
                if isinstance(cd[i], numpy.ndarray):
                    d = cd[i]
                else:
                    d = numpy.array(cd[i]).astype(numpy.float32)
                data[n].append(d)

        logger.info('Creating cache file {}'.format(cache_filename))
        try:
            if self._cache_file_format == ".h5":
                with h5py.File(cache_filename, 'w') as h5:
                    for k, v in data.items():
                        h5.create_dataset(k, data=v)
            else:
                retry_count = 1
                while True:
                    try:
                        with open(cache_filename, 'wb') as f:
                            for v in data.values():
                                numpy.save(f, v)
                        break
                    except OSError:
                        retry_count += 1
                        if retry_count > 10:
                            raise
                        logger.info(
                            'Creating cache retry {}/10'.format(retry_count))
        except Exception:
            logger.critical(
                'An error occurred while creating cache file from dataset.')
            # Mismatched per-sample shapes are the usual cause; report them.
            for k, v in data.items():
                if not v:
                    continue
                size = v[0].shape
                for d in v:
                    if size != d.shape:
                        logger.critical('The sizes of data "{}" are not the same. ({} != {})'.format(
                            k, size, d.shape))
            raise

        self._cache_file_names.append(cache_filename)
        self._cache_file_order.append(len(self._cache_file_order))
        self._cache_file_data_orders.append(list(range(len(cache_data))))
        self._cache_positions = []

    def _store_data_to_cache_buffer(self, position):
        '''Buffer ``position`` and flush to a cache file when the buffer is full or all data is seen.'''
        self._cache_positions.append(position)
        if position == self._total_cached_size:
            self._total_cached_size += 1
            if len(self._cache_positions) >= self._cache_size or self._total_cached_size >= self.size:
                self._save_cache_to_file()

    def _get_data_from_cache_file(self, position):
        '''
        Return the sample at ``position`` (in the current epoch order) from the cache files.

        The most recently loaded cache file is kept in memory. Returns None if a
        ``.npy`` cache file is missing.
        '''
        cache_file_index = self._cache_file_positions[position]
        cache_data_position = \
            self._cache_file_data_orders[cache_file_index][
                position - self._cache_file_start_positions[cache_file_index]]

        if self._current_cache_file_index != cache_file_index:
            cache_file_name = self._cache_file_names[cache_file_index]
            if self._cache_file_format == '.npy':
                if not os.path.exists(cache_file_name):
                    return None
                loaded = {}
                with open(cache_file_name, 'rb') as f:
                    for v in self._variables:
                        loaded[v] = numpy.load(f)
            else:
                # Dataset[()] reads the whole dataset; Dataset.value was
                # removed in h5py 3.0.
                with h5py.File(cache_file_name, 'r') as h5:
                    loaded = {k: v[()] for k, v in h5.items()}
            # Commit the index only after a successful load, so a missing file
            # does not leave an empty cache marked as current.
            self._current_cache_data = loaded
            self._current_cache_file_index = cache_file_index

        d = [self._current_cache_data[v][cache_data_position]
             for v in self.variables]
        return d

    def _get_data(self, position):
        with self._thread_lock:
            self._position = position
            return self._get_data_from_cache_file(position)

    def _create_cache(self):
        '''Write all data of the wrapped source into cache files plus index/info/original/order CSVs.'''
        # Save all data into cache file(s).
        self._cache_positions = []
        self._position = 0

        if single_or_rankzero():
            progress(None)

        while self._position < self._data_source._size:
            if single_or_rankzero():
                progress('Create cache', self._position * 1.0 /
                         self._data_source._size)
            self._store_data_to_cache_buffer(self._position)
            self._position += 1
        if len(self._cache_positions) > 0:
            self._save_cache_to_file()

        if single_or_rankzero():
            progress(None)

        # Adjust data size into reseted position. In most case it means
        # multiple of bunch(mini-batch) size.
        num_of_cache_files = int(numpy.ceil(
            float(self._data_source._size) / self._cache_size))

        self._cache_file_order = self._cache_file_order[0:num_of_cache_files]
        self._cache_file_data_orders = self._cache_file_data_orders[0:num_of_cache_files]
        remainder = self._data_source._size % self._cache_size
        if remainder != 0:
            last = num_of_cache_files - 1
            self._cache_file_data_orders[last] = self._cache_file_data_orders[last][0:remainder]

        # Create Index
        index_filename = os.path.join(self._cache_dir, "cache_index.csv")
        with open(index_filename, 'w') as f:
            writer = csv.writer(f, lineterminator='\n')
            for fn, orders in zip(self._cache_file_names, self._cache_file_data_orders):
                writer.writerow((os.path.basename(fn), len(orders)))

        # Create Info (variable names; .npy files do not store them)
        if self._cache_file_format == ".npy":
            info_filename = os.path.join(self._cache_dir, "cache_info.csv")
            with open(info_filename, 'w') as f:
                writer = csv.writer(f, lineterminator='\n')
                for variable in self._variables:
                    writer.writerow((variable, ))

        # Create original.csv
        if self._data_source._original_source_uri is not None:
            fr = FileReader(self._data_source._original_source_uri)
            with fr.open() as f:
                csv_lines = [x.decode('utf-8') for x in f.readlines()]
            with open(os.path.join(self._cache_dir, "original.csv"), 'w') as o:
                o.writelines(csv_lines)

        # Create order.csv
        if self._data_source._order is not None and \
                self._data_source._original_order is not None:
            with open(os.path.join(self._cache_dir, "order.csv"), 'w') as o:
                writer = csv.writer(o, lineterminator='\n')
                for orders in zip(self._data_source._original_order, self._data_source._order):
                    writer.writerow(list(orders))

    def _create_cache_file_position_table(self):
        '''Rebuild the position -> (cache file, in-file offset) lookup tables from the current orders.'''
        # Create cached data position table.
        pos = 0
        self._cache_file_start_positions = list(
            range(len(self._cache_file_order)))
        self._order = list(range(len(self._order)))

        self._cache_file_positions = list(range(len(self._order)))
        count = 0
        for cache_file_pos in self._cache_file_order:
            self._cache_file_start_positions[cache_file_pos] = pos
            pos += len(self._cache_file_data_orders[cache_file_pos])
            for j in self._cache_file_data_orders[cache_file_pos]:
                p = j + (cache_file_pos * self._cache_size)
                self._order[count] = p
                self._cache_file_positions[count] = cache_file_pos
                count += 1

    def __init__(self,
                 data_source,
                 cache_dir=None,
                 cache_file_name_prefix='cache',
                 shuffle=False,
                 rng=None):
        # Set before super().__init__, which registers close() with atexit.
        self._tempdir_created = False
        logger.info('Using DataSourceWithFileCache')
        super(DataSourceWithFileCache, self).__init__(shuffle=shuffle, rng=rng)
        self._cache_file_name_prefix = cache_file_name_prefix
        self._cache_dir = cache_dir
        logger.info('Cache Directory is {}'.format(self._cache_dir))

        self._cache_size = int(nnabla_config.get(
            'DATA_ITERATOR', 'data_source_file_cache_size'))
        logger.info('Cache size is {}'.format(self._cache_size))

        self._num_of_threads = int(nnabla_config.get(
            'DATA_ITERATOR', 'data_source_file_cache_num_of_threads'))
        logger.info('Num of thread is {}'.format(self._num_of_threads))

        self._cache_file_format = nnabla_config.get(
            'DATA_ITERATOR', 'cache_file_format')
        logger.info('Cache file format is {}'.format(self._cache_file_format))

        self._thread_lock = threading.Lock()

        self._size = data_source._size
        self._variables = data_source.variables
        self._data_source = data_source
        self._generation = -1
        self._cache_positions = []
        self._total_cached_size = 0
        self._cache_file_names = []
        self._cache_file_order = []
        self._cache_file_start_positions = []
        self._cache_file_data_orders = []

        self._current_cache_file_index = -1
        self._current_cache_data = None

        self.shuffle = shuffle
        self._original_order = list(range(self._size))
        self._order = list(range(self._size))

        # Without an explicit cache_dir, use a temporary directory that
        # close() removes again.
        if self._cache_dir is None:
            self._tempdir_created = True
            if nnabla_config.get('DATA_ITERATOR', 'data_source_file_cache_location') != '':
                self._cache_dir = tempfile.mkdtemp(dir=nnabla_config.get(
                    'DATA_ITERATOR', 'data_source_file_cache_location'))
            else:
                self._cache_dir = tempfile.mkdtemp()
            logger.info(
                'Tempdir for cache {} created.'.format(self._cache_dir))
        self._closed = False
        atexit.register(self.close)

        self._create_cache()
        self._create_cache_file_position_table()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()

    def close(self):
        '''Remove a self-created temp cache dir and close the wrapped source (once).'''
        if not self._closed:
            if six.PY3:
                atexit.unregister(self.close)
            if self._tempdir_created:
                rmtree(self._cache_dir, ignore_errors=True)
            self._data_source.close()
            self._closed = True

    def reset(self):
        '''Start a new epoch: reshuffle file/in-file orders if shuffling, rebuild tables, rewind.'''
        with self._thread_lock:
            if self._shuffle:
                # Shuffle at two levels: cache-file order, then order inside each file.
                self._cache_file_order = list(
                    self._rng.permutation(self._cache_file_order))
                for i in range(len(self._cache_file_data_orders)):
                    self._cache_file_data_orders[i] = list(
                        self._rng.permutation(self._cache_file_data_orders[i]))
                self._order = []
                for i in self._cache_file_order:
                    self._order += self._cache_file_data_orders[i]

            self._create_cache_file_position_table()
            self._data_source.reset()
            self._position = 0
            self._generation += 1
class DataSourceWithMemoryCache(DataSource):
    '''
    This class contains properties and methods for data source that can be
    read from memory cache, which is utilized by data iterator.

    The whole dataset is kept in memory only when its estimated size (first
    sample's byte size times the number of samples) is below the
    ``data_source_buffer_max_size`` config value; otherwise every access is
    forwarded to ``data_source``.

    Args:
        data_source (:py:class:`DataSource <nnabla.utils.data_source.DataSource>`):
            Instance of DataSource class which provides data.
        shuffle (bool):
            Indicates whether the dataset is shuffled or not.
        rng (None or :obj:`numpy.random.RandomState`): Numpy random number
            generator.
    '''

    def _get_data_func(self, position):
        return self._data_source._get_data(position)

    def _get_data(self, position):
        if self._on_memory:
            if self._order[position] < len(self._cache):
                data = self._cache[self._order[position]]
            else:
                # NOTE(review): appending assumes uncached positions are first
                # requested sequentially, so that cache index == position.
                data = self._get_data_func(position)
                self._cache.append(data)
        else:
            data = self._data_source._get_data(position)
        self._position = position
        return data

    @staticmethod
    def _field_nbytes(d):
        '''Return the byte size of one sample field (used for the memory estimate).'''
        if isinstance(d, list):
            d = numpy.array(d, dtype=numpy.float32)
        elif not (hasattr(d, 'size') and hasattr(d, 'itemsize')):
            # Python scalars / tuples have no size/itemsize; measure them as
            # arrays instead of failing with AttributeError.
            d = numpy.asarray(d)
        return d.size * d.itemsize

    def __init__(self, data_source, shuffle=False, rng=None):
        logger.info('Using DataSourceWithMemoryCache')
        super(DataSourceWithMemoryCache, self).__init__(
            shuffle=shuffle, rng=rng)
        self._buffer_max_size = int(nnabla_config.get(
            'DATA_ITERATOR', 'data_source_buffer_max_size'))
        self._size = data_source._size
        self._variables = data_source.variables
        self._data_source = data_source
        self._order = list(range(self._size))

        self._on_memory = False
        self._cache = []

        # Estimate the footprint from the first sample only. An empty source
        # has nothing to probe and trivially fits in memory.
        self._data_size = 0
        if self._size > 0:
            for d in self._get_data_func(0):
                self._data_size += self._field_nbytes(d)
        total_size = self._data_size * self._size
        if total_size < self._buffer_max_size:
            logger.info('On-memory')
            self._on_memory = True
        self._generation = -1
        self._closed = False
        atexit.register(self.close)

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()

    def close(self):
        '''Close the wrapped source and drop the atexit hook (once).'''
        if not self._closed:
            if six.PY3:
                atexit.unregister(self.close)
            self._data_source.close()
            self._closed = True

    def reset(self):
        '''Start a new epoch; see inline notes for the on-memory bookkeeping.'''
        if self._on_memory:
            self._generation += 1
            # Shuffle only from the second generation on; the first pass
            # fills the cache in natural order.
            if self._shuffle and self._generation > 0:
                self._order = list(self._rng.permutation(self._size))
            else:
                self._order = list(range(self._size))

            if self._position == 0:
                # NOTE(review): a reset before any sample was read restores
                # generation -1 — presumably so the first real epoch is
                # generation 0 and unshuffled; confirm.
                self._generation = -1
            else:
                self._data_source._position = self._position
                self._data_source.reset()
        else:
            self._data_source.reset()
            self._generation = self._data_source._generation
            self._position = self._data_source._position

        super(DataSourceWithMemoryCache, self).reset()
class SlicedDataSource(DataSource):
    '''
    Provides sliced data source.

    Position ``p`` of this source maps to position ``slice_start + p`` of
    ``data_source``, for ``p`` in ``[0, slice_end - slice_start)``.

    Args:
        data_source (:py:class:`DataSource <nnabla.utils.data_source.DataSource>`):
            Instance of DataSource class which provides data.
        shuffle (bool): Stored via :py:class:`DataSource`; ``_get_data`` here
            maps positions directly without reordering.
        rng (None or :obj:`numpy.random.RandomState`): Numpy random number
            generator.
        slice_start (int or None): First position (inclusive) of the slice.
            None means 0.
        slice_end (int or None): End position (exclusive) of the slice.
            None means ``data_source._size``.
    '''

    def __init__(self, data_source, shuffle=False, rng=None, slice_start=None, slice_end=None):
        logger.info('Using SlicedDataSource')
        super(SlicedDataSource, self).__init__(shuffle=shuffle, rng=rng)

        # The defaults are None; previously they made `slice_end - slice_start`
        # raise TypeError. Interpret None as "from the beginning"/"to the end".
        if slice_start is None:
            slice_start = 0
        if slice_end is None:
            slice_end = data_source._size

        self._data_source = data_source
        self._variables = data_source._variables[:]
        self._slice_start = slice_start
        self._slice_end = slice_end
        self._size = self._slice_end - self._slice_start
        self._generation = -1
        self.reset()

    def reset(self):
        '''Rewind the wrapped source to the slice start and begin a new generation.'''
        self._data_source.reset()
        self._data_source._position = self._slice_start
        self._generation += 1
        self._position = 0

    def _get_data(self, position):
        self._position = position
        data = self._data_source._get_data(self._slice_start + position)
        return data