#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Laurent El Shafey <Laurent.El-Shafey@idiap.ch>
# @date: Wed May 8 19:42:39 CEST 2013
#
# Copyright (C) 2011-2013 Idiap Research Institute, Martigny, Switzerland
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import errno
import os
import shutil
class Database(object):
  """Wrapper class for the MNIST database of handwritten digits (http://yann.lecun.com/exdb/mnist/).

  The four gzip-compressed MNIST files are looked up in a user supplied
  directory or, when none is given, next to this package. If they cannot be
  found they are downloaded, either into the supplied directory or into a
  temporary directory that is erased when the object is destroyed.
  """

  def __init__(self, data_dir = None):
    """Creates the database.

    Keyword parameters:

    data_dir
      The path to the directory containing the four binary files available
      from http://yann.lecun.com/exdb/mnist/. If the files are not found
      there they are downloaded into it (the directory is created if
      needed). If None, the files shipped with this package are used or,
      when they are missing, they are downloaded into a temporary directory.
    """
    # initialize members
    self.m_labels = set(range(0,10))
    self.m_groups = ('train', 'test')
    self.m_mnist_filenames = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz',
                              't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']
    self.m_tmp_dir = None

    # check if the data is available in the given directory (or if not given, in the default directory)
    if not self._db_is_installed(data_dir):
      self.m_data_dir = self._create_tmp_dir_and_download(data_dir)
      if data_dir is None:
        # if we create a temporary directory, mark it to be deleted at the end
        self.m_tmp_dir = self.m_data_dir
    elif data_dir is not None:
      self.m_data_dir = data_dir
    else:
      # imported lazily: only needed when falling back to the package directory
      from pkg_resources import resource_filename
      self.m_data_dir = os.path.dirname(resource_filename(__name__, 'query.py'))

    self.m_train_fname_images = os.path.join(self.m_data_dir, self.m_mnist_filenames[0])
    self.m_train_fname_labels = os.path.join(self.m_data_dir, self.m_mnist_filenames[1])
    self.m_test_fname_images = os.path.join(self.m_data_dir, self.m_mnist_filenames[2])
    self.m_test_fname_labels = os.path.join(self.m_data_dir, self.m_mnist_filenames[3])

  def __del__(self):
    """Erases the temporary download directory, if this object created one.

    A directory that has already disappeared is silently accepted; any other
    failure to erase it is reported through a RuntimeError.
    """
    # getattr: the attribute is missing when __init__ never ran
    tmp_dir = getattr(self, 'm_tmp_dir', None)
    if not tmp_dir:
      return
    # forget the directory first, so that a second call does not try again
    self.m_tmp_dir = None
    try:
      shutil.rmtree(tmp_dir) # delete directory
    except OSError as e:
      if e.errno != errno.ENOENT: # ENOENT - no such file or directory
        raise RuntimeError("bob.db.mnist: Error while erasing temporarily downloaded data files")

  def _db_is_installed(self, directory = None):
    """Tells whether all four MNIST files exist.

    Keyword parameters:

    directory
      The directory to look into. If None, the files are searched as
      resources of this package.

    Returns True if every file exists, False otherwise. Only the existence of
    the files is checked, not their content.
    """
    if directory is None:
      # imported lazily: not required when an explicit directory is given
      from pkg_resources import resource_filename
      db_files = [resource_filename(__name__, k) for k in self.m_mnist_filenames]
    else:
      db_files = [os.path.join(directory, k) for k in self.m_mnist_filenames]
    return all(os.path.exists(f) for f in db_files)

  def _download_file(self, url, destination):
    """Downloads url into the file destination.

    If the transfer fails the (incomplete) destination file is removed before
    the error is propagated, because _db_is_installed() only checks for the
    existence of the files and would accept a truncated one.
    """
    from contextlib import closing
    try:
      from urllib.request import urlopen # python3
    except ImportError:
      from urllib2 import urlopen # python2

    with closing(urlopen(url)) as response:
      out_file = open(destination, 'wb')
      complete = False
      try:
        shutil.copyfileobj(response, out_file)
        complete = True
      finally:
        out_file.close()
        if not complete:
          os.remove(destination)

  def _create_tmp_dir_and_download(self, directory=None):
    """Downloads the four MNIST files.

    Keyword parameters:

    directory
      The directory to download the files to; it is created if it does not
      exist. If None, a new temporary directory is created.

    Returns the directory that contains the downloaded files. If a download
    fails, the error is propagated; a temporary directory created by this
    call is erased in that case.
    """
    import tempfile

    created_tmp_dir = directory is None
    if created_tmp_dir:
      directory = tempfile.mkdtemp(prefix='mnist_db')
    elif not os.path.exists(directory):
      os.makedirs(directory)

    print("Downloading the mnist database from http://yann.lecun.com/exdb/mnist/ ...")
    try:
      for f in self.m_mnist_filenames:
        self._download_file('http://yann.lecun.com/exdb/mnist/'+f, os.path.join(directory, f))
    except Exception:
      # nobody knows about the temporary directory yet, so erase it here
      if created_tmp_dir:
        shutil.rmtree(directory, ignore_errors=True)
      raise
    return directory

  def _read_labels(self, fname):
    """Reads the labels from the original MNIST label binary file

    Returns a writable 1D numpy.ndarray of uint8 holding every byte that
    follows the 8 byte header. The header values are not validated.
    """
    import gzip
    import struct
    import numpy

    with gzip.GzipFile(fname, 'rb') as f:
      # reads 2 big-endian integers (magic number and number of examples)
      _magic_nr, _n_examples = struct.unpack(">II", f.read(8))
      payload = f.read()
    # uint8 dataformat (endian-less); frombuffer gives a read-only view on
    # the bytes, hence the copy
    return numpy.frombuffer(payload, dtype='uint8').copy()

  def _read_images(self, fname):
    """Reads the images from the original MNIST image binary file

    Returns a writable 2D numpy.ndarray of uint8 with one row per example and
    rows*cols columns, the three sizes being taken from the 16 byte header.
    """
    import gzip
    import struct
    import numpy

    with gzip.GzipFile(fname, 'rb') as f:
      # reads 4 big-endian integers
      _magic_nr, n_examples, rows, cols = struct.unpack(">IIII", f.read(16))
      payload = f.read()
    shape = (n_examples, rows*cols)
    # uint8 dataformat (endian-less); reshape fails if the payload size does
    # not agree with the header
    return numpy.frombuffer(payload, dtype='uint8').reshape(shape).copy()

  def _check_parameters_for_validity(self, parameters, parameter_description, valid_parameters, default_parameters = None):
    """Checks the given parameters for validity, i.e., if they are contained in the set of valid parameters.
    It also assures that the parameters form a tuple, a list or a set.
    If parameters is None, the default_parameters will be returned (if default_parameters is omitted, all valid_parameters are returned).
    An empty list/tuple/set is NOT replaced by the defaults; it is returned as it is.
    This function will return the checked parameters, or raise a ValueError.

    Keyword parameters:

    parameters
      The parameters to be checked.
      Might be a single value, a list/tuple/set of values, or None.

    parameter_description
      A short description of the parameter.
      This will be used to raise an exception in case the parameter is not valid.

    valid_parameters
      A list/tuple/set of valid values for the parameters.

    default_parameters
      The list/tuple of default parameters that will be returned in case parameters is None.
      If omitted, all valid_parameters are used.
    """
    if parameters is None:
      # parameters are not specified
      parameters = default_parameters if default_parameters is not None else valid_parameters

    if not isinstance(parameters, (list, tuple, set)):
      # parameter is just a single element, not a tuple or list -> transform it into a tuple
      parameters = (parameters,)

    # perform the checks
    for parameter in parameters:
      if parameter not in valid_parameters:
        raise ValueError("Invalid %s '%s'. Valid values are %s, or lists/tuples of those" % (parameter_description, parameter, valid_parameters))

    # check passed, now return the list/tuple of parameters
    return parameters

  def labels(self):
    """Returns the set of labels (the digits 0 to 9)
    """
    return self.m_labels

  def groups(self):
    """Returns the tuple of groups ('train' and 'test')
    """
    return self.m_groups

  def data(self, groups=None, labels=None):
    """Loads the MNIST samples and labels and returns them in NumPy arrays

    Keyword Parameters:

    groups
      One of the groups 'train' or 'test' or a list with both of them (which is the default).

    labels
      A subset of the labels (digits 0 to 9) (everything is the default).

    Returns: A tuple composed of images and labels as numpy arrays considering
    all the filtering criteria and organized as follows:

    images
      A 2D numpy.ndarray with as many rows as examples in the dataset, as many
      columns as pixels (actually, there are 28x28 = 784 columns). The pixels of each
      image are unrolled in C-scan order (i.e., first row 0, then row 1, etc.).

    labels
      A 1D numpy.ndarray with as many elements as examples in the dataset.

    When both groups are requested, the training examples come first. A
    ValueError is raised for an unknown group or label.
    """
    # check if groups and labels are valid
    groups = self._check_parameters_for_validity(groups, "group", self.m_groups)
    vlabels = self._check_parameters_for_validity(labels, "label", self.m_labels)

    # Reads data from the groups
    import numpy
    if 'train' in groups and 'test' in groups:
      all_images = numpy.vstack([self._read_images(self.m_train_fname_images),
                                 self._read_images(self.m_test_fname_images)])
      all_labels = numpy.hstack([self._read_labels(self.m_train_fname_labels),
                                 self._read_labels(self.m_test_fname_labels)])
    elif 'train' in groups:
      all_images = self._read_images(self.m_train_fname_images)
      all_labels = self._read_labels(self.m_train_fname_labels)
    elif 'test' in groups:
      all_images = self._read_images(self.m_test_fname_images)
      all_labels = self._read_labels(self.m_test_fname_labels)
    else:
      # only reached when an empty list/tuple/set of groups was given
      all_images = numpy.empty(shape=(0,784), dtype=numpy.uint8)
      all_labels = numpy.empty(shape=(0,), dtype=numpy.uint8)

    # Mask of the examples whose label is in the list of requested labels
    # (list(): numpy.isin would treat a set as one single object)
    selected = numpy.isin(all_labels, list(vlabels))
    return all_images[selected,:], all_labels[selected]